author     Peter Hawkins <phawkins@google.com>              2017-01-09 12:04:37 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-01-09 12:26:35 -0800
commit     1e67c90e2caceeff82d09793d1ef5fa0300d219b (patch)
tree       6567ea8b0fa01fcfcd608b7e4c636865d33c7032
parent     7ad7e4dfae4344d6b955b5eb61dc4b6bb792f1b3 (diff)
Initial open-source release of XLA: Accelerated Linear Algebra.
XLA is a compiler-based linear algebra execution engine that targets CPUs, GPUs and custom accelerators. XLA is still experimental; we are releasing it early to get the community involved.

Change: 143990941
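For orientation, the client library added under tensorflow/compiler/xla/client/ (see the file list below) exposes a builder-style API: a computation is assembled with a builder object and then handed to a client for compilation and execution. The sketch below is hypothetical; every class and method name in it is assumed from the header file names in this change (client_library.h, computation_builder.h) rather than from verified signatures.

    // Hypothetical sketch of building a scalar add with the new client API.
    // All names are assumptions inferred from client_library.h and
    // computation_builder.h; exact signatures may differ.
    #include "tensorflow/compiler/xla/client/client_library.h"
    #include "tensorflow/compiler/xla/client/computation_builder.h"

    void BuildScalarAdd() {
      auto* client = xla::ClientLibrary::LocalClientOrDie();  // assumed entry point
      xla::ComputationBuilder builder(client, "scalar_add");
      auto x = builder.ConstantR0<float>(1.0f);  // rank-0 (scalar) constant
      auto y = builder.ConstantR0<float>(2.0f);
      builder.Add(x, y);                         // last operation becomes the root
      auto computation = builder.Build();        // assumed StatusOr-wrapped result
      (void)computation;  // would then be compiled and run through the client
    }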
-rwxr-xr-xconfigure20
-rw-r--r--tensorflow/BUILD20
-rw-r--r--tensorflow/compiler/aot/BUILD218
-rw-r--r--tensorflow/compiler/aot/benchmark.cc138
-rw-r--r--tensorflow/compiler/aot/benchmark.h70
-rw-r--r--tensorflow/compiler/aot/benchmark_main.template51
-rw-r--r--tensorflow/compiler/aot/benchmark_test.cc46
-rw-r--r--tensorflow/compiler/aot/codegen.cc579
-rw-r--r--tensorflow/compiler/aot/codegen.h53
-rw-r--r--tensorflow/compiler/aot/codegen_test.cc137
-rw-r--r--tensorflow/compiler/aot/codegen_test_h.golden268
-rw-r--r--tensorflow/compiler/aot/compile.cc416
-rw-r--r--tensorflow/compiler/aot/compile.h92
-rw-r--r--tensorflow/compiler/aot/flags.cc72
-rw-r--r--tensorflow/compiler/aot/flags.h48
-rw-r--r--tensorflow/compiler/aot/runtime.cc98
-rw-r--r--tensorflow/compiler/aot/runtime.h58
-rw-r--r--tensorflow/compiler/aot/runtime_test.cc125
-rw-r--r--tensorflow/compiler/aot/test.cc94
-rw-r--r--tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt16
-rw-r--r--tensorflow/compiler/aot/test_graph_tfadd.pbtxt63
-rw-r--r--tensorflow/compiler/aot/tests/BUILD146
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py119
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt16
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt10
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt16
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt18
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt25
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc381
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl285
-rw-r--r--tensorflow/compiler/aot/tfcompile.proto43
-rw-r--r--tensorflow/compiler/aot/tfcompile_main.cc142
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.cc119
-rw-r--r--tensorflow/compiler/aot/tfcompile_util.h36
-rw-r--r--tensorflow/compiler/aot/tfcompile_util_test.cc185
-rw-r--r--tensorflow/compiler/jit/BUILD282
-rw-r--r--tensorflow/compiler/jit/build_xla_launch_ops_pass.cc215
-rw-r--r--tensorflow/compiler/jit/build_xla_launch_ops_pass.h31
-rw-r--r--tensorflow/compiler/jit/defs.cc22
-rw-r--r--tensorflow/compiler/jit/defs.h29
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc660
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.h86
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc397
-rw-r--r--tensorflow/compiler/jit/graph_to_functiondef.cc274
-rw-r--r--tensorflow/compiler/jit/graph_to_functiondef.h33
-rw-r--r--tensorflow/compiler/jit/graph_to_functiondef_test.cc87
-rw-r--r--tensorflow/compiler/jit/graphcycles/BUILD41
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles.cc391
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles.h128
-rw-r--r--tensorflow/compiler/jit/graphcycles/graphcycles_test.cc515
-rw-r--r--tensorflow/compiler/jit/jit_compilation_pass_registration.cc37
-rw-r--r--tensorflow/compiler/jit/legacy_flags/BUILD67
-rw-r--r--tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc63
-rw-r--r--tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h50
-rw-r--r--tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc76
-rw-r--r--tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h59
-rw-r--r--tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc68
-rw-r--r--tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h52
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc534
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.h55
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc357
-rw-r--r--tensorflow/compiler/jit/parallel_check_op.cc154
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.cc199
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.h112
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc60
-rw-r--r--tensorflow/compiler/jit/xla_device.cc219
-rw-r--r--tensorflow/compiler/jit/xla_device.h120
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc181
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h92
-rw-r--r--tensorflow/compiler/jit/xla_device_launch_op.cc171
-rw-r--r--tensorflow/compiler/jit/xla_device_launch_op.h50
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.cc36
-rw-r--r--tensorflow/compiler/jit/xla_device_ops.h118
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc65
-rw-r--r--tensorflow/compiler/jit/xla_local_launch_op.cc342
-rw-r--r--tensorflow/compiler/jit/xla_local_launch_op.h55
-rw-r--r--tensorflow/compiler/tests/BUILD352
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py749
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl78
-rw-r--r--tensorflow/compiler/tests/clustering_test.py102
-rw-r--r--tensorflow/compiler/tests/concat_ops_test.py374
-rw-r--r--tensorflow/compiler/tests/conv2d_test.py526
-rw-r--r--tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc30
-rw-r--r--tensorflow/compiler/tests/dynamic_stitch_test.py86
-rw-r--r--tensorflow/compiler/tests/function_test.py130
-rw-r--r--tensorflow/compiler/tests/jit_test.py459
-rw-r--r--tensorflow/compiler/tests/lrn_ops_test.py129
-rw-r--r--tensorflow/compiler/tests/lstm.py158
-rw-r--r--tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt20
-rw-r--r--tensorflow/compiler/tests/lstm_layer_inference.pbtxt5828
-rw-r--r--tensorflow/compiler/tests/lstm_test.py293
-rw-r--r--tensorflow/compiler/tests/nary_ops_test.py209
-rw-r--r--tensorflow/compiler/tests/nullary_ops_test.py61
-rw-r--r--tensorflow/compiler/tests/pooling_ops_test.py511
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc2097
-rw-r--r--tensorflow/compiler/tests/reduce_ops_test.py125
-rw-r--r--tensorflow/compiler/tests/ternary_ops_test.py110
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py346
-rw-r--r--tensorflow/compiler/tests/xla_device_test.py81
-rw-r--r--tensorflow/compiler/tests/xla_test.py148
-rw-r--r--tensorflow/compiler/tf2xla/BUILD193
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc139
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.h33
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis_test.cc83
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph.cc78
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph.h50
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph_flags.cc63
-rw-r--r--tensorflow/compiler/tf2xla/dump_graph_flags.h48
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD177
-rw-r--r--tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc47
-rw-r--r--tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc141
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bcast_ops.cc87
-rw-r--r--tensorflow/compiler/tf2xla/kernels/bias_ops.cc119
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc158
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cast_op.cc71
-rw-r--r--tensorflow/compiler/tf2xla/kernels/concat_op.cc210
-rw-r--r--tensorflow/compiler/tf2xla/kernels/conv_ops.cc373
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc177
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h109
-rw-r--r--tensorflow/compiler/tf2xla/kernels/declaration_op.cc127
-rw-r--r--tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc235
-rw-r--r--tensorflow/compiler/tf2xla/kernels/diag_op.cc255
-rw-r--r--tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc200
-rw-r--r--tensorflow/compiler/tf2xla/kernels/fill_op.cc74
-rw-r--r--tensorflow/compiler/tf2xla/kernels/function_ops.cc110
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op.cc104
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc69
-rw-r--r--tensorflow/compiler/tf2xla/kernels/identity_op.cc39
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops.cc142
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc49
-rw-r--r--tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc51
-rw-r--r--tensorflow/compiler/tf2xla/kernels/l2loss_op.cc53
-rw-r--r--tensorflow/compiler/tf2xla/kernels/lrn_ops.cc173
-rw-r--r--tensorflow/compiler/tf2xla/kernels/matmul_op.cc88
-rw-r--r--tensorflow/compiler/tf2xla/kernels/no_op.cc24
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pack_op.cc93
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pad_op.cc80
-rw-r--r--tensorflow/compiler/tf2xla/kernels/pooling_ops.cc374
-rw-r--r--tensorflow/compiler/tf2xla/kernels/random_ops.cc116
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.cc157
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops.h71
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc150
-rw-r--r--tensorflow/compiler/tf2xla/kernels/relu_op.cc93
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reshape_op.cc101
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc79
-rw-r--r--tensorflow/compiler/tf2xla/kernels/select_op.cc90
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sequence_ops.cc213
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc245
-rw-r--r--tensorflow/compiler/tf2xla/kernels/slice_op.cc121
-rw-r--r--tensorflow/compiler/tf2xla/kernels/softmax_op.cc152
-rw-r--r--tensorflow/compiler/tf2xla/kernels/split_op.cc208
-rw-r--r--tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc223
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tile_ops.cc128
-rw-r--r--tensorflow/compiler/tf2xla/kernels/transpose_op.cc134
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unary_ops.cc70
-rw-r--r--tensorflow/compiler/tf2xla/kernels/unpack_op.cc90
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc65
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h42
-rw-r--r--tensorflow/compiler/tf2xla/literal_util_test.cc71
-rw-r--r--tensorflow/compiler/tf2xla/op_registrations.cc502
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.cc54
-rw-r--r--tensorflow/compiler/tf2xla/shape_util.h38
-rw-r--r--tensorflow/compiler/tf2xla/str_util.cc44
-rw-r--r--tensorflow/compiler/tf2xla/str_util.h46
-rw-r--r--tensorflow/compiler/tf2xla/str_util_test.cc90
-rw-r--r--tensorflow/compiler/tf2xla/type_util.cc68
-rw-r--r--tensorflow/compiler/tf2xla/type_util.h30
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc203
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h214
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc405
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h203
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc331
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h277
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc142
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.h73
-rw-r--r--tensorflow/compiler/tf2xla/xla_local_runtime_context.h55
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc253
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h174
-rw-r--r--tensorflow/compiler/xla/.clang-format3
-rw-r--r--tensorflow/compiler/xla/BUILD561
-rw-r--r--tensorflow/compiler/xla/README.md1
-rw-r--r--tensorflow/compiler/xla/array2d.cc36
-rw-r--r--tensorflow/compiler/xla/array2d.h165
-rw-r--r--tensorflow/compiler/xla/array2d_test.cc132
-rw-r--r--tensorflow/compiler/xla/array3d.h127
-rw-r--r--tensorflow/compiler/xla/array3d_test.cc93
-rw-r--r--tensorflow/compiler/xla/array4d.h272
-rw-r--r--tensorflow/compiler/xla/array4d_test.cc180
-rw-r--r--tensorflow/compiler/xla/client/BUILD175
-rw-r--r--tensorflow/compiler/xla/client/client.cc479
-rw-r--r--tensorflow/compiler/xla/client/client.h202
-rw-r--r--tensorflow/compiler/xla/client/client_library.cc107
-rw-r--r--tensorflow/compiler/xla/client/client_library.h103
-rw-r--r--tensorflow/compiler/xla/client/computation.cc67
-rw-r--r--tensorflow/compiler/xla/client/computation.h76
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.cc1539
-rw-r--r--tensorflow/compiler/xla/client/computation_builder.h783
-rw-r--r--tensorflow/compiler/xla/client/global_data.cc42
-rw-r--r--tensorflow/compiler/xla/client/global_data.h46
-rw-r--r--tensorflow/compiler/xla/client/lib/BUILD60
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.cc67
-rw-r--r--tensorflow/compiler/xla/client/lib/arithmetic.h45
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.cc59
-rw-r--r--tensorflow/compiler/xla/client/lib/testing.h43
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc371
-rw-r--r--tensorflow/compiler/xla/client/local_client.h263
-rw-r--r--tensorflow/compiler/xla/client/padding.cc122
-rw-r--r--tensorflow/compiler/xla/client/padding.h58
-rw-r--r--tensorflow/compiler/xla/client/padding_test.cc91
-rw-r--r--tensorflow/compiler/xla/device_util.h39
-rw-r--r--tensorflow/compiler/xla/differential_set.h63
-rw-r--r--tensorflow/compiler/xla/differential_set_test.cc51
-rw-r--r--tensorflow/compiler/xla/executable_run_options.cc70
-rw-r--r--tensorflow/compiler/xla/executable_run_options.h87
-rw-r--r--tensorflow/compiler/xla/index_util.cc126
-rw-r--r--tensorflow/compiler/xla/index_util.h69
-rw-r--r--tensorflow/compiler/xla/index_util_test.cc159
-rw-r--r--tensorflow/compiler/xla/layout_util.cc363
-rw-r--r--tensorflow/compiler/xla/layout_util.h153
-rw-r--r--tensorflow/compiler/xla/layout_util_test.cc246
-rw-r--r--tensorflow/compiler/xla/legacy_flags/BUILD267
-rw-r--r--tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc62
-rw-r--r--tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h46
-rw-r--r--tensorflow/compiler/xla/legacy_flags/backend_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/backend_flags.h46
-rw-r--r--tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h46
-rw-r--r--tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc61
-rw-r--r--tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h47
-rw-r--r--tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h47
-rw-r--r--tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc76
-rw-r--r--tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h54
-rw-r--r--tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc71
-rw-r--r--tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h51
-rw-r--r--tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc91
-rw-r--r--tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h56
-rw-r--r--tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc73
-rw-r--r--tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h54
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h47
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc62
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h48
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h47
-rw-r--r--tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc107
-rw-r--r--tensorflow/compiler/xla/legacy_flags/layout_util_flags.h62
-rw-r--r--tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc67
-rw-r--r--tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h58
-rw-r--r--tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h46
-rw-r--r--tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc206
-rw-r--r--tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h66
-rw-r--r--tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc190
-rw-r--r--tensorflow/compiler/xla/legacy_flags/service_flags.cc100
-rw-r--r--tensorflow/compiler/xla/legacy_flags/service_flags.h69
-rw-r--r--tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc63
-rw-r--r--tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h47
-rw-r--r--tensorflow/compiler/xla/legacy_flags/util_flags.cc62
-rw-r--r--tensorflow/compiler/xla/legacy_flags/util_flags.h45
-rw-r--r--tensorflow/compiler/xla/literal_util.cc989
-rw-r--r--tensorflow/compiler/xla/literal_util.h1004
-rw-r--r--tensorflow/compiler/xla/literal_util_test.cc622
-rw-r--r--tensorflow/compiler/xla/map_util.h65
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.cc92
-rw-r--r--tensorflow/compiler/xla/packed_literal_reader.h59
-rw-r--r--tensorflow/compiler/xla/port/BUILD33
-rw-r--r--tensorflow/compiler/xla/port/initialize.h39
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc133
-rw-r--r--tensorflow/compiler/xla/primitive_util.h157
-rw-r--r--tensorflow/compiler/xla/protobuf_util.cc35
-rw-r--r--tensorflow/compiler/xla/protobuf_util.h35
-rw-r--r--tensorflow/compiler/xla/ptr_util.h80
-rw-r--r--tensorflow/compiler/xla/reference_util.cc540
-rw-r--r--tensorflow/compiler/xla/reference_util.h382
-rw-r--r--tensorflow/compiler/xla/reference_util_test.cc306
-rw-r--r--tensorflow/compiler/xla/service/BUILD1216
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc938
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.h56
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc1368
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc215
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h178
-rw-r--r--tensorflow/compiler/xla/service/backend.cc237
-rw-r--r--tensorflow/compiler/xla/service/backend.h191
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc777
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h358
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc1051
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.cc259
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.h215
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness_test.cc487
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.cc91
-rw-r--r--tensorflow/compiler/xla/service/channel_tracker.h94
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.cc78
-rw-r--r--tensorflow/compiler/xla/service/compilation_cache.h78
-rw-r--r--tensorflow/compiler/xla/service/compiler.cc96
-rw-r--r--tensorflow/compiler/xla/service/compiler.h172
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.cc57
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.h83
-rw-r--r--tensorflow/compiler/xla/service/computation_tracker.cc204
-rw-r--r--tensorflow/compiler/xla/service/computation_tracker.h139
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc439
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.h54
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion_test.cc1153
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD529
-rw-r--r--tensorflow/compiler/xla/service/cpu/build_defs.bzl11
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.cc220
-rw-r--r--tensorflow/compiler/xla/service/cpu/compiler_functor.h69
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc148
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization.h44
-rw-r--r--tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc146
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc631
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.h148
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc477
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h150
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc44
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h37
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc120
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h46
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc52
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.h91
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc36
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h50
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc47
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h50
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc138
-rw-r--r--tensorflow/compiler/xla/service/cpu/disassembler.cc182
-rw-r--r--tensorflow/compiler/xla/service/cpu/disassembler.h63
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc346
-rw-r--r--tensorflow/compiler/xla/service/cpu/dot_op_emitter.h90
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc68
-rw-r--r--tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h43
-rw-r--r--tensorflow/compiler/xla/service/cpu/infeed_manager.cc72
-rw-r--r--tensorflow/compiler/xla/service/cpu/infeed_manager.h95
-rw-r--r--tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc102
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc127
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emission_utils.h32
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc1774
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h402
-rw-r--r--tensorflow/compiler/xla/service/cpu/layout_assignment.cc124
-rw-r--r--tensorflow/compiler/xla/service/cpu/layout_assignment.h41
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc365
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h124
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc43
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_conv2d.h39
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h87
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul.cc81
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_matmul.h42
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc39
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h39
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc73
-rw-r--r--tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h42
-rw-r--r--tensorflow/compiler/xla/service/cpu/sample_harness.cc75
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc189
-rw-r--r--tensorflow/compiler/xla/service/cpu/simple_orc_jit.h88
-rw-r--r--tensorflow/compiler/xla/service/cpu_transfer_manager.cc108
-rw-r--r--tensorflow/compiler/xla/service/cpu_transfer_manager.h47
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.cc77
-rw-r--r--tensorflow/compiler/xla/service/device_memory_allocator.h84
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.cc78
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h289
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h226
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc934
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h118
-rw-r--r--tensorflow/compiler/xla/service/executable.cc82
-rw-r--r--tensorflow/compiler/xla/service/executable.h168
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.cc95
-rw-r--r--tensorflow/compiler/xla/service/execution_tracker.h105
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc183
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h77
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD533
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.cc139
-rw-r--r--tensorflow/compiler/xla/service/gpu/buffer_allocations.h113
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_folding.cc443
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_folding.h34
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc552
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.cc324
-rw-r--r--tensorflow/compiler/xla/service/gpu/convolution_thunk.h149
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_insertion.cc71
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_insertion.h36
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.cc41
-rw-r--r--tensorflow/compiler/xla/service/gpu/copy_thunk.h56
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc396
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h91
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.cc50
-rw-r--r--tensorflow/compiler/xla/service/gpu/for_thunk.h52
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc189
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.h71
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.cc335
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_compiler.h78
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc454
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h130
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc207
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.h67
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc368
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc168
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h109
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc90
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.h39
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc126
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc200
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.h72
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.cc645
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter.h405
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_context.h74
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc120
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc1745
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.cc94
-rw-r--r--tensorflow/compiler/xla/service/gpu/kernel_thunk.h86
-rw-r--r--tensorflow/compiler/xla/service/gpu/layout_assignment.cc142
-rw-r--r--tensorflow/compiler/xla/service/gpu/layout_assignment.h41
-rw-r--r--tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc85
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD88
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc103
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h51
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc489
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h43
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll141
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc65
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h50
-rw-r--r--tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc55
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.cc408
-rw-r--r--tensorflow/compiler/xla/service/gpu/pad_insertion.h43
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc98
-rw-r--r--tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h58
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.cc99
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.h75
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.cc45
-rw-r--r--tensorflow/compiler/xla/service/gpu/sequential_thunk.h54
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.cc135
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.h46
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc132
-rw-r--r--tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc52
-rw-r--r--tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h47
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk.h90
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk_schedule.cc163
-rw-r--r--tensorflow/compiler/xla/service/gpu/thunk_schedule.h93
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.cc49
-rw-r--r--tensorflow/compiler/xla/service/gpu/tuple_thunk.h60
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.cc74
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_thunk.h62
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.cc532
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer.h43
-rw-r--r--tensorflow/compiler/xla/service/gpu/while_transformer_test.cc218
-rw-r--r--tensorflow/compiler/xla/service/graphviz_example.cc165
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc520
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h300
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc311
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc350
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h147
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc337
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc134
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.h46
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc428
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.cc69
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce.h43
-rw-r--r--tensorflow/compiler/xla/service/hlo_dce_test.cc97
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc87
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.h71
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc507
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.h76
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc1921
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h791
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc894
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc269
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.h132
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.cc53
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_config.h92
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_test.cc101
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc164
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h107
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode_test.cc30
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass.h68
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc64
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.h66
-rw-r--r--tensorflow/compiler/xla/service/hlo_query.cc89
-rw-r--r--tensorflow/compiler/xla/service/hlo_query.h63
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc45
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification.h34
-rw-r--r--tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc205
-rw-r--r--tensorflow/compiler/xla/service/inliner.cc123
-rw-r--r--tensorflow/compiler/xla/service/inliner.h39
-rw-r--r--tensorflow/compiler/xla/service/inliner_test.cc109
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc295
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h84
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion_test.cc140
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc1334
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h302
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc486
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD154
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/README.md2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc195
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h93
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc147
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h94
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.cc274
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ir_array.h248
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc197
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h230
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc471
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/llvm_util.h228
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc103
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h79
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ops.cc100
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/ops.h79
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc543
-rw-r--r--tensorflow/compiler/xla/service/local_service.h185
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer.cc39
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer.h153
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.cc37
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h53
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc166
-rw-r--r--tensorflow/compiler/xla/service/platform_util.h61
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.cc120
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover.h36
-rw-r--r--tensorflow/compiler/xla/service/reshape_mover_test.cc57
-rw-r--r--tensorflow/compiler/xla/service/service.cc1428
-rw-r--r--tensorflow/compiler/xla/service/service.h457
-rw-r--r--tensorflow/compiler/xla/service/session.proto91
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc1380
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h219
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc1133
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc168
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.h137
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc143
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h151
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager_test.cc159
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc109
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.h41
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding_test.cc149
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc495
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h268
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc544
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc2117
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h336
-rw-r--r--tensorflow/compiler/xla/service/versioned_computation_handle.h48
-rw-r--r--tensorflow/compiler/xla/service_interface.h117
-rw-r--r--tensorflow/compiler/xla/shape_layout.cc78
-rw-r--r--tensorflow/compiler/xla/shape_layout.h88
-rw-r--r--tensorflow/compiler/xla/shape_tree.h260
-rw-r--r--tensorflow/compiler/xla/shape_tree_test.cc134
-rw-r--r--tensorflow/compiler/xla/shape_util.cc1024
-rw-r--r--tensorflow/compiler/xla/shape_util.h393
-rw-r--r--tensorflow/compiler/xla/shape_util_test.cc506
-rw-r--r--tensorflow/compiler/xla/status.h46
-rw-r--r--tensorflow/compiler/xla/status_macros.cc170
-rw-r--r--tensorflow/compiler/xla/status_macros.h220
-rw-r--r--tensorflow/compiler/xla/status_macros_test.cc112
-rw-r--r--tensorflow/compiler/xla/statusor.cc46
-rw-r--r--tensorflow/compiler/xla/statusor.h300
-rw-r--r--tensorflow/compiler/xla/statusor_test.cc645
-rw-r--r--tensorflow/compiler/xla/test_helpers.cc69
-rw-r--r--tensorflow/compiler/xla/test_helpers.h355
-rw-r--r--tensorflow/compiler/xla/tests/BUILD1436
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc1662
-rw-r--r--tensorflow/compiler/xla/tests/axpy_simple_test.cc90
-rw-r--r--tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc85
-rw-r--r--tensorflow/compiler/xla/tests/batch_normalization_test.cc210
-rw-r--r--tensorflow/compiler/xla/tests/binop_scaling_test.cc157
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc179
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_test.cc286
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl149
-rw-r--r--tensorflow/compiler/xla/tests/call_test.cc115
-rw-r--r--tensorflow/compiler/xla/tests/check_execution_arity_test.cc138
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc263
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h409
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc127
-rw-r--r--tensorflow/compiler/xla/tests/codegen_test_base.cc90
-rw-r--r--tensorflow/compiler/xla/tests/codegen_test_base.h56
-rw-r--r--tensorflow/compiler/xla/tests/compilation_cache_test.cc218
-rw-r--r--tensorflow/compiler/xla/tests/compute_constant_test.cc249
-rw-r--r--tensorflow/compiler/xla/tests/concat_test.cc523
-rw-r--r--tensorflow/compiler/xla/tests/constants_test.cc193
-rw-r--r--tensorflow/compiler/xla/tests/convert_test.cc210
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc117
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc361
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc1294
-rw-r--r--tensorflow/compiler/xla/tests/copy_test.cc277
-rw-r--r--tensorflow/compiler/xla/tests/custom_call_test.cc148
-rw-r--r--tensorflow/compiler/xla/tests/deallocation_test.cc155
-rw-r--r--tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc215
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc387
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc506
-rw-r--r--tensorflow/compiler/xla/tests/floor_ceil_test.cc128
-rw-r--r--tensorflow/compiler/xla/tests/fmax_test.cc61
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc589
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc204
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h107
-rw-r--r--tensorflow/compiler/xla/tests/inprocess_service_test.cc204
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.cc566
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util.h274
-rw-r--r--tensorflow/compiler/xla/tests/literal_test_util_test.cc102
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test.cc55
-rw-r--r--tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc111
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc220
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h146
-rw-r--r--tensorflow/compiler/xla/tests/log_test.cc75
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc589
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc179
-rw-r--r--tensorflow/compiler/xla/tests/multidimensional_slice_test.cc74
-rw-r--r--tensorflow/compiler/xla/tests/pad_test.cc420
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc357
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc115
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc238
-rw-r--r--tensorflow/compiler/xla/tests/query_inferred_shape_test.cc61
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc506
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc445
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc168
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc77
-rw-r--r--tensorflow/compiler/xla/tests/reshape_test.cc811
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc173
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc160
-rw-r--r--tensorflow/compiler/xla/tests/round_trip_transfer_test.cc164
-rw-r--r--tensorflow/compiler/xla/tests/scalar_computations_test.cc630
-rw-r--r--tensorflow/compiler/xla/tests/select_and_scatter_test.cc395
-rw-r--r--tensorflow/compiler/xla/tests/select_test.cc276
-rw-r--r--tensorflow/compiler/xla/tests/set_return_value_test.cc116
-rw-r--r--tensorflow/compiler/xla/tests/slice_test.cc277
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.h76
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.h115
-rw-r--r--tensorflow/compiler/xla/tests/transpose_test.cc203
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc415
-rw-r--r--tensorflow/compiler/xla/tests/unary_op_test.cc179
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc235
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc423
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc395
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.cc155
-rw-r--r--tensorflow/compiler/xla/text_literal_reader.h62
-rw-r--r--tensorflow/compiler/xla/text_literal_reader_test.cc58
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.cc64
-rw-r--r--tensorflow/compiler/xla/text_literal_writer.h48
-rw-r--r--tensorflow/compiler/xla/text_literal_writer_test.cc52
-rw-r--r--tensorflow/compiler/xla/tools/BUILD191
-rw-r--r--tensorflow/compiler/xla/tools/convert_computation.cc60
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc76
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc111
-rw-r--r--tensorflow/compiler/xla/tools/dumped_computation_to_text.cc83
-rw-r--r--tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc76
-rw-r--r--tensorflow/compiler/xla/tools/replay_computation.cc129
-rw-r--r--tensorflow/compiler/xla/tools/show_literal.cc45
-rw-r--r--tensorflow/compiler/xla/tools/show_signature.cc73
-rw-r--r--tensorflow/compiler/xla/tools/show_text_literal.cc52
-rw-r--r--tensorflow/compiler/xla/types.h37
-rw-r--r--tensorflow/compiler/xla/util.cc238
-rw-r--r--tensorflow/compiler/xla/util.h257
-rw-r--r--tensorflow/compiler/xla/util_test.cc89
-rw-r--r--tensorflow/compiler/xla/window_util.cc142
-rw-r--r--tensorflow/compiler/xla/window_util.h66
-rw-r--r--tensorflow/compiler/xla/xla.bzl22
-rw-r--r--tensorflow/compiler/xla/xla.proto291
-rw-r--r--tensorflow/compiler/xla/xla_data.proto714
-rw-r--r--tensorflow/contrib/compiler/BUILD1
-rw-r--r--tensorflow/core/platform/default/build_config.bzl12
-rwxr-xr-xtensorflow/tools/ci_build/builds/configured4
-rw-r--r--tensorflow/tools/pip_package/BUILD3
-rw-r--r--third_party/llvm/llvm.BUILD1
656 files changed, 138481 insertions, 2 deletions
diff --git a/configure b/configure
index 8d4b12aad2..64add33bd5 100755
--- a/configure
+++ b/configure
@@ -112,6 +112,26 @@ else
sed -i -e "s/WITH_HDFS_SUPPORT = True/WITH_HDFS_SUPPORT = False/" tensorflow/core/platform/default/build_config.bzl
fi
+## Enable XLA.
+while [ "$TF_ENABLE_XLA" == "" ]; do
+ read -p "Do you wish to build TensorFlow with the XLA just-in-time compiler (experimental)? [y/N] " INPUT
+ case $INPUT in
+ [Yy]* ) echo "XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=1;;
+ [Nn]* ) echo "No XLA JIT support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
+ "" ) echo "No XLA support will be enabled for TensorFlow"; TF_ENABLE_XLA=0;;
+ * ) echo "Invalid selection: " $INPUT;;
+ esac
+done
+
+if [ "$TF_ENABLE_XLA" == "1" ]; then
+ # Update Bazel build configuration.
+ perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
+else
+ # Update Bazel build configuration.
+ perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
+fi
+
+
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 44015ff7d9..ef04a6a88e 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -95,6 +95,26 @@ filegroup(
"//tensorflow/c:all_files",
"//tensorflow/cc:all_files",
"//tensorflow/cc/saved_model:all_files",
+ "//tensorflow/compiler/aot:all_files",
+ "//tensorflow/compiler/aot/tests:all_files",
+ "//tensorflow/compiler/jit:all_files",
+ "//tensorflow/compiler/jit/graphcycles:all_files",
+ "//tensorflow/compiler/jit/legacy_flags:all_files",
+ "//tensorflow/compiler/tests:all_files",
+ "//tensorflow/compiler/tf2xla:all_files",
+ "//tensorflow/compiler/tf2xla/kernels:all_files",
+ "//tensorflow/compiler/xla:all_files",
+ "//tensorflow/compiler/xla/client:all_files",
+ "//tensorflow/compiler/xla/client/lib:all_files",
+ "//tensorflow/compiler/xla/legacy_flags:all_files",
+ "//tensorflow/compiler/xla/port:all_files",
+ "//tensorflow/compiler/xla/service:all_files",
+ "//tensorflow/compiler/xla/service/cpu:all_files",
+ "//tensorflow/compiler/xla/service/gpu:all_files",
+ "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend:all_files",
+ "//tensorflow/compiler/xla/service/llvm_ir:all_files",
+ "//tensorflow/compiler/xla/tests:all_files",
+ "//tensorflow/compiler/xla/tools:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
new file mode 100644
index 0000000000..c52a56b642
--- /dev/null
+++ b/tensorflow/compiler/aot/BUILD
@@ -0,0 +1,218 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+
+# Optional runtime utilities for use by code generated by tfcompile.
+cc_library(
+ name = "runtime",
+ srcs = ["runtime.cc"],
+ hdrs = ["runtime.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ ],
+)
+
+cc_test(
+ name = "runtime_test",
+ srcs = ["runtime_test.cc"],
+ deps = [
+ ":runtime",
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# Don't depend on this directly; this is only used for the benchmark test
+# generated by tf_library.
+cc_library(
+ name = "tf_library_test_main",
+ testonly = 1,
+ visibility = ["//visibility:public"],
+ deps = ["//tensorflow/core:test_main"],
+)
+
+xla_proto_library(
+ name = "tfcompile_proto",
+ srcs = ["tfcompile.proto"],
+ deps = [
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "tfcompile_lib",
+ srcs = [
+ "codegen.cc",
+ "compile.cc",
+ "flags.cc",
+ "tfcompile_util.cc",
+ ],
+ hdrs = [
+ "codegen.h",
+ "compile.h",
+ "flags.h",
+ "tfcompile_util.h",
+ ],
+ deps = [
+ ":runtime", # needed by codegen to print aligned_buffer_bytes
+ ":tfcompile_proto",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_test(
+ name = "codegen_test",
+ srcs = ["codegen_test.cc"],
+ data = ["codegen_test_h.golden"],
+ deps = [
+ ":tfcompile_lib",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "tfcompile_util_test",
+ srcs = ["tfcompile_util_test.cc"],
+ deps = [
+ ":tfcompile_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_binary(
+ name = "tfcompile",
+ visibility = ["//visibility:public"],
+ deps = [":tfcompile_main"],
+)
+
+cc_library(
+ name = "tfcompile_main",
+ srcs = ["tfcompile_main.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tfcompile_lib",
+ ":tfcompile_proto",
+ "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# NOTE: Most end-to-end tests are in the "tests" subdirectory, to ensure that
+# tfcompile.bzl correctly handles usage from outside of the package that it is
+# defined in.
+
+# A simple test of tf_library from a text protobuf, mostly to enable the
+# benchmark_test.
+tf_library(
+ name = "test_graph_tfadd",
+ testonly = 1,
+ config = "test_graph_tfadd.config.pbtxt",
+ cpp_class = "AddComp",
+ graph = "test_graph_tfadd.pbtxt",
+ tags = ["manual"],
+)
+
+# Utility library for benchmark binaries, used by the *_benchmark rules that are
+# added by the tfcompile bazel macro.
+cc_library(
+ name = "benchmark",
+ srcs = ["benchmark.cc"],
+ hdrs = ["benchmark.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ # The purpose of the benchmark library is to support building an aot
+ # binary with minimal dependencies, to demonstrate small binary sizes.
+ #
+ # KEEP THE DEPENDENCIES MINIMAL.
+ "//tensorflow/core:framework_lite",
+ ],
+)
+
+cc_library(
+ name = "benchmark_extra_android",
+ tags = [
+ "manual",
+ "notap",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_test(
+ name = "benchmark_test",
+ srcs = ["benchmark_test.cc"],
+ tags = ["manual"],
+ deps = [
+ ":benchmark",
+ ":test_graph_tfadd",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+test_suite(
+ name = "all_tests",
+ tags = ["manual"],
+ tests = [
+ ":benchmark_test",
+ ":test_graph_tfadd_test",
+ "//tensorflow/compiler/aot/tests:all_tests",
+ ],
+)
+
+exports_files([
+ "benchmark_main.template", # used by tf_library(...,gen_benchmark=True)
+ "test.cc", # used by tf_library(...,gen_test=True)
+])
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/aot/benchmark.cc b/tensorflow/compiler/aot/benchmark.cc
new file mode 100644
index 0000000000..0c5e2c103e
--- /dev/null
+++ b/tensorflow/compiler/aot/benchmark.cc
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The purpose of the benchmark library is to support building an aot binary
+// with minimal dependencies, to demonstrate small binary sizes.
+//
+// KEEP THE DEPENDENCIES MINIMAL.
+
+#include "tensorflow/compiler/aot/benchmark.h"
+
+#include <sys/time.h>
+
+#include <algorithm>
+#include <functional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace benchmark {
+
+// Returns current wall time in micros.
+//
+// TODO(b/33546473): Refactor tensorflow::Env::NowMicros() so that we can re-use
+// the implementation without pulling in all of the Env dependencies.
+static double NowMicros() {
+ struct timeval tv;
+ gettimeofday(&tv, NULL);
+ return static_cast<uint64>(tv.tv_sec) * 1000000 + tv.tv_usec;
+}
+
+void DumpStatsToStdout(const Stats& stats) {
+ // Compute stats.
+ std::vector<int64> sorted_us(stats.per_iter_us);
+ std::sort(sorted_us.begin(), sorted_us.end());
+ const size_t count_us = sorted_us.size();
+ double sum_us = 0;
+ size_t count_us_trimmed = 0;
+ double sum_us_trimmed = 0;
+ size_t count_us_best = 0;
+ double sum_us_best = 0;
+ static constexpr float trim_ratio = 0.25;
+ static constexpr float best_ratio = 0.1;
+ const size_t count_trimmed = count_us * trim_ratio;
+ const size_t count_best = count_us * best_ratio;
+ for (size_t i = 0; i < sorted_us.size(); ++i) {
+ const int64 us = sorted_us[i];
+ sum_us += us;
+ if (i >= count_trimmed && i < count_us - count_trimmed) {
+ sum_us_trimmed += us;
+ ++count_us_trimmed;
+ }
+ if (i < count_best) {
+ sum_us_best += us;
+ ++count_us_best;
+ }
+ }
+ // Prepare nicely-formatted data.
+ const int kBufSize = 1000;
+ char buf[kBufSize];
+ snprintf(buf, kBufSize, "Mean with %2.0f%% trimmed:", trim_ratio * 100);
+ const string label_trimmed(buf);
+ snprintf(buf, kBufSize, "Mean of %2.0f%% best:", best_ratio * 100);
+ const string label_best(buf);
+ std::vector<std::pair<string, double>> groups = {
+ {"Best:", sorted_us.front()},
+ {"Worst:", sorted_us.back()},
+ {"Median:", sorted_us[count_us / 2]},
+ {"Mean:", sum_us / count_us},
+ {label_trimmed, sum_us_trimmed / count_us_trimmed},
+ {label_best, sum_us_best / count_us_best},
+ };
+ int max_label_size = 0;
+ double max_us = 0;
+ for (const auto& g : groups) {
+ if (g.first.size() > max_label_size) {
+ max_label_size = g.first.size();
+ }
+ if (g.second > max_us) {
+ max_us = g.second;
+ }
+ }
+ int max_digits = 1;
+ while (max_us >= 10.0) {
+ max_us /= 10.0;
+ ++max_digits;
+ }
+ // Dump stats out.
+ printf("Benchmark ran %zu iterations over %lld us\n", count_us,
+ stats.total_us);
+ for (const auto& g : groups) {
+ printf(" %-*s %*.3f us\n", max_label_size, g.first.c_str(), max_digits + 4,
+ g.second);
+ }
+}
+
+void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats) {
+ // If neither max_micros nor max_iters is set, stop at kDefaultMicros.
+ const int64 max_us = (options.max_micros <= 0 && options.max_iters <= 0)
+ ? Options::kDefaultMicros
+ : options.max_micros;
+ printf("Running benchmark for %lld us\n", max_us);
+ const int64 start_us = NowMicros();
+ int64 iters = 0;
+ while (true) {
+ const int64 iter_start_us = NowMicros();
+ fn();
+ const int64 end_us = NowMicros();
+ // Collect stats and decide whether to stop.
+ stats->per_iter_us.push_back(end_us - iter_start_us);
+ const int64 total_us = end_us - start_us;
+ ++iters;
+ if ((max_us > 0 && total_us >= max_us) ||
+ (options.max_iters > 0 && iters >= options.max_iters)) {
+ stats->total_us = total_us;
+ break;
+ }
+ }
+}
+
+} // namespace benchmark
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/benchmark.h b/tensorflow/compiler/aot/benchmark.h
new file mode 100644
index 0000000000..266b7fefc7
--- /dev/null
+++ b/tensorflow/compiler/aot/benchmark.h
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Contains benchmark functions used with the code-generated benchmarks that can
+// be used to test a model on android. See also code generation rules in
+// tfcompile.bzl.
+//
+// This is separate from the built-in micro-benchmarks, because we want to:
+// 1. show a binary with minimal dependencies, to show a close-to-lower-bound
+// binary size.
+// 2. compile on Android.
+#ifndef TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
+#define TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace benchmark {
+
+// Options specifies options for benchmarks of functions generated by tfcompile.
+struct Options {
+ // kDefaultMicros specifies the default time to run the benchmark, and is used
+ // if neither max_iters nor max_micros is set.
+ static const int64 kDefaultMicros = 3000000;
+
+ int64 max_iters = 0; // Maximum iterations to run, ignored if <= 0.
+ int64 max_micros = 0; // Maximum microseconds to run, ignored if <= 0.
+};
+
+// Stats holds statistics collected during benchmarking.
+struct Stats {
+ std::vector<int64> per_iter_us; // Per-iteration deltas in us.
+ int64 total_us; // Total time in us.
+
+ Stats() : total_us(0) { per_iter_us.reserve(5000); }
+};
+
+// DumpStatsToStdout printfs to stdout stats in a multi-line human-friendly
+// form.
+void DumpStatsToStdout(const Stats& stats);
+
+// BenchmarkFn is the signature of the function generated by tfcompile.
+typedef std::function<void()> BenchmarkFn;
+
+// Benchmark runs a benchmark of the function `fn`, collecting stats in
+// `stats`. Use `options` to configure how the benchmark is run.
+void Benchmark(const Options& options, const BenchmarkFn& fn, Stats* stats);
+
+} // namespace benchmark
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_BENCHMARK_H_
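
For orientation, here is a minimal sketch of driving the API above directly, outside the generated benchmark main; the computation being timed (RunMyComputationOnce) and the chosen limits are placeholders, not part of this library:

    #include "tensorflow/compiler/aot/benchmark.h"

    // Placeholder for a compiled computation wrapped in a no-arg callable.
    void RunMyComputationOnce();

    int main() {
      tensorflow::tfcompile::benchmark::Options options;
      options.max_iters = 100;          // Stop after 100 iterations...
      options.max_micros = 500 * 1000;  // ...or 500ms, whichever comes first.
      tensorflow::tfcompile::benchmark::Stats stats;
      tensorflow::tfcompile::benchmark::Benchmark(
          options, [] { RunMyComputationOnce(); }, &stats);
      tensorflow::tfcompile::benchmark::DumpStatsToStdout(stats);
      return 0;
    }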
diff --git a/tensorflow/compiler/aot/benchmark_main.template b/tensorflow/compiler/aot/benchmark_main.template
new file mode 100644
index 0000000000..a4df6ed1cf
--- /dev/null
+++ b/tensorflow/compiler/aot/benchmark_main.template
@@ -0,0 +1,51 @@
+// Generated by the tf_library build rule. DO NOT EDIT!
+//
+// This file contains the main function and logic for benchmarking code
+// generated by tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be
+// rewritten to real values before this file can be compiled.
+//
+// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
+// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
+//
+// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
+// generates a cc_binary rule for you.
+
+// These macros must be defined before eigen files are included.
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+
+// clang-format off
+#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
+// clang-format on
+
+#include "tensorflow/compiler/aot/benchmark.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+// Macros that expand to tokens based on the entry point name.
+// clang-format off
+#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
+// clang-format on
+
+namespace tensorflow {
+namespace tfcompile {
+
+int Main(int argc, char** argv) {
+ Eigen::ThreadPool pool(1 /* num_threads */);
+ Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
+
+ CPP_CLASS computation;
+ computation.set_thread_pool(&device);
+
+ benchmark::Options options;
+ benchmark::Stats stats;
+ benchmark::Benchmark(options, [&] { computation.Run(); }, &stats);
+ benchmark::DumpStatsToStdout(stats);
+ return 0;
+}
+
+} // namespace tfcompile
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ return tensorflow::tfcompile::Main(argc, argv);
+}
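
For concreteness, if tf_library were invoked with a generated header named "my_class.h" and cpp_class set to "foo::bar::MyClass" (both hypothetical values), the token rewriting would turn the include and macro above into roughly:

    // Hypothetical expansion of the {{TFCOMPILE_*}} tokens:
    #include "my_class.h"                // was "{{TFCOMPILE_HEADER}}"
    #define CPP_CLASS foo::bar::MyClass  // was {{TFCOMPILE_CPP_CLASS}}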
diff --git a/tensorflow/compiler/aot/benchmark_test.cc b/tensorflow/compiler/aot/benchmark_test.cc
new file mode 100644
index 0000000000..0568c5b1d5
--- /dev/null
+++ b/tensorflow/compiler/aot/benchmark_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/benchmark.h"
+
+#include "tensorflow/compiler/aot/test_graph_tfadd.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace benchmark {
+namespace {
+
+// There isn't much we can verify in a stable fashion, so we just run the
+// benchmark with a fixed max_iters, and ensure we end up with that many
+// per-iteration stats.
+TEST(Benchmark, Benchmark) {
+ AddComp add;
+
+ Options options;
+ options.max_iters = 1;
+ Stats stats1;
+ Benchmark(options, [&] { add.Run(); }, &stats1);
+ EXPECT_EQ(stats1.per_iter_us.size(), 1);
+
+ options.max_iters = 5;
+ Stats stats5;
+ Benchmark(options, [&] { add.Run(); }, &stats5);
+ EXPECT_EQ(stats5.per_iter_us.size(), 5);
+}
+
+} // namespace
+} // namespace benchmark
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
new file mode 100644
index 0000000000..042a72745a
--- /dev/null
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -0,0 +1,579 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/codegen.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/aot/runtime.h"
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/str_util.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+namespace {
+
+// Convert an XLA type into a C++ type.
+Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
+ switch (type) {
+ case xla::PRED:
+ *str = "bool";
+ break;
+ case xla::S8:
+ *str = "tensorflow::int8";
+ break;
+ case xla::S16:
+ *str = "tensorflow::int16";
+ break;
+ case xla::S32:
+ *str = "tensorflow::int32";
+ break;
+ case xla::S64:
+ *str = "tensorflow::int64";
+ break;
+ case xla::U8:
+ *str = "tensorflow::uint8";
+ break;
+ case xla::U16:
+ *str = "tensorflow::uint16";
+ break;
+ case xla::U32:
+ *str = "tensorflow::uint32";
+ break;
+ case xla::U64:
+ *str = "tensorflow::uint64";
+ break;
+ case xla::F32:
+ *str = "float";
+ break;
+ case xla::F64:
+ *str = "double";
+ break;
+ default:
+ return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
+ " has no equivalent in C++");
+ }
+ return Status::OK();
+}
+
+// total_buffer_bytes returns the sum of each size in `sizes`, skipping -1
+// values. There are `n` entries in `sizes`.
+size_t total_buffer_bytes(const intptr_t* sizes, size_t n) {
+ size_t total = 0;
+ for (size_t i = 0; i < n; ++i) {
+ if (sizes[i] != -1) {
+ total += sizes[i];
+ }
+ }
+ return total;
+}
+
+// Fills in arg_sizes with the byte size of each positional arg.
+Status ComputeArgSizes(const CompileResult& compile_result,
+ std::vector<int64>* arg_sizes) {
+ const xla::ProgramShape& ps = compile_result.program_shape;
+ for (int i = 0; i < ps.parameters_size(); ++i) {
+ if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) {
+      // If the compiled function needs an XlaLocalRuntimeContext* arg, it's
+ // always last, and must be represented as an opaque type.
+ const xla::PrimitiveType type = ps.parameters(i).element_type();
+ if (type != xla::OPAQUE) {
+ return errors::InvalidArgument(
+ "expected final context arg to be opaque, but got type: ",
+ xla::PrimitiveType_Name(type), ", from program shape: ",
+ xla::ShapeUtil::HumanString(ps));
+ }
+ arg_sizes->push_back(-1);
+ } else {
+ arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf(
+ ps.parameters(i), compile_result.pointer_size));
+ }
+ }
+ return Status::OK();
+}
+
+// Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
+// are used to generate methods for args and results.
+Status AddRewritesForShape(int i, const xla::Shape& shape,
+ std::vector<std::pair<string, string>>* rewrites) {
+ string type;
+ TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
+ std::vector<string> dim_vars;
+ string dim_sizes, indices;
+ if (xla::ShapeUtil::Rank(shape) == 0 ||
+ (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
+ dim_sizes = "[1]";
+ indices = "[0]";
+ } else {
+ for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
+ dim_vars.push_back(strings::StrCat("size_t dim", dim));
+ dim_sizes += strings::StrCat("[", shape.dimensions(dim), "]");
+ indices += strings::StrCat("[dim", dim, "]");
+ }
+ }
+ rewrites->push_back({"{{I}}", strings::StrCat(i)});
+ rewrites->push_back({"{{TYPE}}", type});
+ rewrites->push_back({"{{DIM_VARS}}", str_util::Join(dim_vars, ", ")});
+ rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
+ rewrites->push_back({"{{INDICES}}", indices});
+ return Status::OK();
+}
+
+// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
+// for the name. Note that the rewriting strategy is roughly O(N*M), where N is
+// the size of the code and M is the number of rewrites. It's fine for now
+// since N and M are pretty small.
+//
+// TODO(toddw): If this becomes a problem, we should be able to change the
+// algorithm to O(N) by using a state machine, e.g. regexps or a real
+// text-templating mechanism.
+string RewriteWithName(const string& name, string code,
+ const std::vector<std::pair<string, string>>& rewrites) {
+ str_util::ReplaceAllPairs(&code, rewrites);
+ str_util::ReplaceAll(&code, "{{NAME}}", name);
+ return code;
+}
+
+// Generate methods for args (inputs).
+Status GenArgMethods(const Config& config, const xla::ProgramShape& ps,
+ const CompileResult& compile_result, string* methods) {
+ *methods += R"(
+ void** args() { return args_; }
+ const void *const *args() const { return args_; }
+)";
+ size_t num_args = ps.parameters_size();
+ if (compile_result.has_context_arg) {
+    // If the compiled function needs an XlaLocalRuntimeContext* arg, it's
+ // always last, and is set in the class constructor.
+ num_args--;
+ }
+ if (config.feed_size() != num_args) {
+ return errors::InvalidArgument("mismatch between feed_size(",
+ config.feed_size(), ") and num_args(",
+ num_args, ")");
+ }
+ for (int i = 0; i < num_args; ++i) {
+ std::vector<std::pair<string, string>> rewrites;
+ TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
+ const string code = R"(
+ void set_arg{{NAME}}_data(void* data) {
+ args_[{{I}}] = data;
+ }
+ {{TYPE}}* arg{{NAME}}_data() {
+ return static_cast<{{TYPE}}*>(args_[{{I}}]);
+ }
+ {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
+ return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
+ args_[{{I}}])){{INDICES}};
+ }
+ const {{TYPE}}* arg{{NAME}}_data() const {
+ return static_cast<const {{TYPE}}*>(args_[{{I}}]);
+ }
+ const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
+ return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
+ args_[{{I}}])){{INDICES}};
+ }
+)";
+ *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+ if (!config.feed(i).name().empty()) {
+ *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
+ }
+ }
+ return Status::OK();
+}
+
+// Generate methods for results (outputs).
+Status GenResultMethods(const Config& config, const xla::ProgramShape& ps,
+ string* methods) {
+ if (ps.result().element_type() != xla::TUPLE) {
+ // Non-tuple (i.e. single-result) case.
+ if (config.fetch_size() != 1) {
+ return errors::InvalidArgument(
+ "non-tuple result implies 1 fetch, but got ", config.fetch_size(),
+ " fetches");
+ }
+ *methods += R"(
+ void** results() { return temps_ + kResultIndex; }
+ const void *const *results() const { return temps_ + kResultIndex; }
+)";
+ std::vector<std::pair<string, string>> rewrites;
+ TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites));
+ const string code = R"(
+ {{TYPE}}* result{{NAME}}_data() {
+ return static_cast<{{TYPE}}*>(temps_[kResultIndex]);
+ }
+ {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
+ return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
+ temps_[kResultIndex])){{INDICES}};
+ }
+ const {{TYPE}}* result{{NAME}}_data() const {
+ return static_cast<const {{TYPE}}*>(temps_[kResultIndex]);
+ }
+ const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
+ return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
+ temps_[kResultIndex])){{INDICES}};
+ }
+)";
+ *methods += RewriteWithName("0", code, rewrites);
+ if (!config.fetch(0).name().empty()) {
+ *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites);
+ }
+ return Status::OK();
+ }
+ // Tuple (i.e. multi-result) case.
+ if (config.fetch_size() != ps.result().tuple_shapes_size()) {
+ return errors::InvalidArgument("mismatch between fetch_size(",
+                                   config.fetch_size(), ") and tuple_size(",
+ ps.result().tuple_shapes_size(), ")");
+ }
+ *methods += R"(
+ void** results() {
+ return static_cast<void**>(temps_[kResultIndex]);
+ }
+ const void *const *results() const {
+ return static_cast<const void *const *>(temps_[kResultIndex]);
+ }
+)";
+ for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
+ std::vector<std::pair<string, string>> rewrites;
+ TF_RETURN_IF_ERROR(
+ AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
+ string code = R"(
+ {{TYPE}}* result{{NAME}}_data() {
+ return static_cast<{{TYPE}}*>(
+ static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ }
+ {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
+ return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
+ static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ }
+ const {{TYPE}}* result{{NAME}}_data() const {
+ return static_cast<{{TYPE}}*>(
+ static_cast<void**>(temps_[kResultIndex])[{{I}}]);
+ }
+ const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
+ return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
+ static_cast<void**>(temps_[kResultIndex])[{{I}}])){{INDICES}};
+ }
+)";
+ *methods += RewriteWithName(strings::StrCat(i), code, rewrites);
+ if (!config.fetch(i).name().empty()) {
+ *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+ const CompileResult& compile_result, string* header) {
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
+ const int64 result_index = compile_result.aot->result_buffer_index();
+ const xla::BufferSizes& temp_sizes = compile_result.aot->buffer_sizes();
+  if (result_index < 0 || result_index >= temp_sizes.size()) {
+ return errors::InvalidArgument("result index: ", result_index,
+ " is outside the range of temp sizes: [0,",
+ temp_sizes.size(), ")");
+ }
+
+ // Compute sizes and generate methods.
+ std::vector<int64> arg_sizes;
+ TF_RETURN_IF_ERROR(ComputeArgSizes(compile_result, &arg_sizes));
+ const xla::ProgramShape& ps = compile_result.program_shape;
+ string methods_arg, methods_result;
+ TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
+ TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
+ const std::vector<intptr_t> iarg(arg_sizes.begin(), arg_sizes.end());
+ const std::vector<intptr_t> itemp(temp_sizes.begin(), temp_sizes.end());
+ const size_t arg_bytes_aligned =
+ runtime::aligned_buffer_bytes(iarg.data(), iarg.size());
+ const size_t arg_bytes_total = total_buffer_bytes(iarg.data(), iarg.size());
+ const size_t temp_bytes_aligned =
+ runtime::aligned_buffer_bytes(itemp.data(), itemp.size());
+ const size_t temp_bytes_total =
+ total_buffer_bytes(itemp.data(), itemp.size());
+
+ // Create rewrite strings for the optional context arg.
+ string context_include;
+ string context_set_arg, context_set_thread_pool, context_member_var;
+ string run_result = "true";
+ string error_msg = "tensorflow::string()";
+ if (compile_result.has_context_arg) {
+ // NOTE: Extra spaces and newlines are used to ensure nice formatting.
+ context_include =
+ "#include "
+ "\"tensorflow/compiler/tf2xla/"
+ "xla_local_runtime_context.h\"\n";
+ context_set_arg = " args_[kNumArgs-1] = &context_;\n";
+ context_set_thread_pool = " context_.thread_pool = pool;\n";
+ context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n";
+ run_result = "!context_.error";
+ error_msg = "context_.error_msg";
+ }
+
+ // Create rewrite strings for namespace start and end.
+ string ns_start;
+ for (const string& n : opts.namespaces) {
+ ns_start += strings::StrCat("namespace ", n, " {\n");
+ }
+ ns_start += "\n";
+ string ns_end("\n");
+ for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
+ const string& n = opts.namespaces[i];
+ ns_end += strings::StrCat("} // end namespace ", n, "\n");
+ }
+
+ // Use a poor-man's text templating mechanism; first populate the full header
+ // with placeholder tokens, and then rewrite the tokens with real values.
+ *header =
+ R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
+//
+// This header was generated via ahead-of-time compilation of a TensorFlow
+// graph. An object file corresponding to this header was also generated.
+// This header gives access to the functionality in that object file.
+//
+// clang-format off
+
+#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
+#define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard)
+
+{{CONTEXT_INCLUDE}}
+#include "tensorflow/compiler/aot/runtime.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace Eigen { class ThreadPoolDevice; }
+
+// (Implementation detail) Entry point to the function in the object file.
+extern "C" void {{ENTRY}}(
+ void* result, xla::ExecutableRunOptions* run_options,
+ void** args, void** temps);
+
+{{NS_START}}
+// {{CLASS}} represents a computation previously specified in a
+// TensorFlow graph, now compiled into executable code. Usage example:
+//
+// {{CLASS}} computation;
+// // ...set args using computation.argN methods
+// CHECK(computation.Run());
+// // ...inspect results using computation.resultN methods
+//
+// The Run method invokes the actual computation, with inputs read from arg
+// buffers, and outputs written to result buffers. Each Run call may also use
+// a set of temporary buffers for the computation.
+//
+// By default each instance of this class manages its own arg, result and temp
+// buffers. The AllocMode constructor parameter may be used to modify the
+// buffer allocation strategy.
+//
+// Under the default allocation strategy, this class is thread-compatible:
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while
+// it is guaranteed that no thread may call a non-const method.
+//
+// The logical function signature is:
+// {{PROGRAM_SHAPE}}
+//
+// Memory stats:
+// arg bytes total: {{ARG_BYTES_TOTAL}}
+// arg bytes aligned: {{ARG_BYTES_ALIGNED}}
+// temp bytes total: {{TEMP_BYTES_TOTAL}}
+// temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
+class {{CLASS}} {
+ public:
+ // Number of input arguments for the compiled computation.
+ static constexpr size_t kNumArgs = {{ARG_NUM}};
+
+ // Byte size of each argument buffer. There are kNumArgs entries.
+ static const intptr_t* ArgSizes() {
+ static constexpr intptr_t kArgSizes[kNumArgs] = {{{ARG_SIZES}}};
+ return kArgSizes;
+ }
+
+ // AllocMode controls the buffer allocation mode.
+ enum class AllocMode {
+ // Allocate all buffers - args, results and temps.
+ ARGS_RESULTS_AND_TEMPS,
+
+ // Only allocate result and temp buffers.
+ // Use set_argN_data to set argument buffers before Run is called.
+ RESULTS_AND_TEMPS_ONLY,
+ };
+
+ {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
+ }
+{{CONTEXT_SET_ARG}}
+ alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
+ }
+
+ ~{{CLASS}}() {
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
+ }
+
+ // Sets the thread pool to use during the Run call.
+ {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
+ run_options_.set_intra_op_thread_pool(pool);
+{{CONTEXT_SET_THREAD_POOL}}
+ return *this;
+ }
+
+ // Runs the computation, with inputs read from arg buffers, and outputs
+ // written to result buffers. Returns true on success and false on failure.
+ bool Run() {
+ {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_);
+ return {{RUN_RESULT}};
+ }
+
+ // Returns the error message from the previous failed Run call.
+ tensorflow::string error_msg() const { return {{ERROR_MSG}}; }
+
+ // Arg methods for managing input buffers. Buffers are in row-major order.
+ // There is a set of methods for each positional argument, with the following
+ // general form:
+ //
+ // void set_argN_data(void* data)
+ // Sets the buffer of type T for positional argument N. May be called in
+  //     any AllocMode. Must be called before Run to have an effect. Must be
+ // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
+ // to set the argument buffers.
+ //
+ // T* argN_data()
+ // Returns the buffer of type T for positional argument N.
+ //
+ // T& argN(...dim indices...)
+ // Returns a reference to the value of type T for positional argument N,
+ // with dim indices specifying which value. No bounds checking is performed
+ // on dim indices.
+ //
+ // void** args()
+ // Returns an array of argument buffers, where args()[N] is the buffer for
+ // positional argument N.
+{{METHODS_ARG}}
+
+ // Result methods for managing output buffers. Buffers are in row-major order.
+ // Must only be called after a successful Run call. There is a set of methods
+ // for each positional result, with the following general form:
+ //
+ // T* resultN_data()
+ // Returns the buffer of type T for positional result N.
+ //
+ // T& resultN(...dim indices...)
+ // Returns a reference to the value of type T for positional result N,
+ // with dim indices specifying which value. No bounds checking is performed
+ // on dim indices.
+ //
+ // void** results()
+ // Returns an array of result buffers, where results()[N] is the buffer for
+ // positional result N.
+ //
+ // Unlike the arg methods, there is no set_resultN_data method. The result
+ // buffers are managed internally, and may change after each call to Run.
+{{METHODS_RESULT}}
+
+ private:
+ // Number of result and temporary buffers for the compiled computation.
+ static constexpr size_t kNumTemps = {{TEMP_NUM}};
+ // The 0-based index of the result in the temporary buffers.
+ static constexpr size_t kResultIndex = {{RESULT_INDEX}};
+
+ // Byte size of each result / temporary buffer. There are kNumTemps entries.
+ static const intptr_t* TempSizes() {
+ static constexpr intptr_t kTempSizes[kNumTemps] = {{{TEMP_SIZES}}};
+ return kTempSizes;
+ }
+
+ void* args_[kNumArgs];
+ void* temps_[kNumTemps];
+ void* alloc_args_ = nullptr;
+ void* alloc_temps_ = nullptr;
+ xla::ExecutableRunOptions run_options_;
+{{CONTEXT_MEMBER_VAR}}
+
+ TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}});
+};
+{{NS_END}}
+
+#endif // TFCOMPILE_GENERATED_{{ENTRY}}_H_
+
+// clang-format on
+)";
+ // The replacement strategy is naive, but good enough for our purposes.
+ const std::vector<std::pair<string, string>> rewrites = {
+ {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)},
+ {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)},
+ {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())},
+ {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")},
+ {"{{CLASS}}", opts.class_name},
+ {"{{CONTEXT_INCLUDE}}\n", context_include},
+ {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var},
+ {"{{CONTEXT_SET_ARG}}\n", context_set_arg},
+ {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool},
+ {"{{ENTRY}}", compile_result.entry_point},
+ {"{{ERROR_MSG}}", error_msg},
+ {"{{METHODS_ARG}}\n", methods_arg},
+ {"{{METHODS_RESULT}}\n", methods_result},
+ {"{{NS_END}}\n", ns_end},
+ {"{{NS_START}}\n", ns_start},
+ {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
+ {"{{RESULT_INDEX}}", strings::StrCat(result_index)},
+ {"{{RUN_RESULT}}", run_result},
+ {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)},
+ {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)},
+ {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())},
+ {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")},
+ };
+ str_util::ReplaceAllPairs(header, rewrites);
+ return Status::OK();
+}
+
+Status ParseCppClass(const string& cpp_class, string* class_name,
+ std::vector<string>* namespaces) {
+ class_name->clear();
+ namespaces->clear();
+ size_t begin = 0;
+ size_t end = 0;
+ while ((end = cpp_class.find("::", begin)) != string::npos) {
+ const string ns = cpp_class.substr(begin, end - begin);
+ TF_RETURN_IF_ERROR(ValidateCppIdent(
+ ns, "in namespace component of cpp_class: " + cpp_class));
+ namespaces->push_back(ns);
+ begin = end + 2; // +2 to skip the two colons
+ }
+ const string name = cpp_class.substr(begin);
+ TF_RETURN_IF_ERROR(
+ ValidateCppIdent(name, "in class name of cpp_class: " + cpp_class));
+ *class_name = name;
+ return Status::OK();
+}
+
+} // namespace tfcompile
+} // namespace tensorflow
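
As a small illustration of the templating used above, the following sketch applies the same str_util calls that RewriteWithName performs; the template text and rewrite values are made up for the example:

    // Made-up template text and rewrite values, mirroring GenArgMethods.
    std::vector<std::pair<string, string>> rewrites = {{"{{I}}", "0"},
                                                       {"{{TYPE}}", "float"}};
    string code =
        "{{TYPE}}* arg{{NAME}}_data() { return static_cast<{{TYPE}}*>(args_[{{I}}]); }";
    // Equivalent to RewriteWithName("_myfeed", code, rewrites):
    str_util::ReplaceAllPairs(&code, rewrites);
    str_util::ReplaceAll(&code, "{{NAME}}", "_myfeed");
    // code == "float* arg_myfeed_data() { return static_cast<float*>(args_[0]); }"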
diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h
new file mode 100644
index 0000000000..7217c57739
--- /dev/null
+++ b/tensorflow/compiler/aot/codegen.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_
+#define TENSORFLOW_COMPILER_AOT_CODEGEN_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/aot/compile.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+// HeaderOpts specifies options for header-file generation.
+struct HeaderOpts {
+ // The name of the generated C++ class, wrapping the generated function.
+ string class_name;
+
+ // Namespaces specifies a list of C++ namespaces to add to the generated
+ // header. If empty, all symbols will be in the global namespace.
+ std::vector<string> namespaces;
+};
+
+// GenerateHeader uses the meta-information from compile_result to generate a
+// C++ header giving access to the function in the generated object file. The
+// header includes API usage documentation.
+Status GenerateHeader(const HeaderOpts& opts, const Config& config,
+ const CompileResult& compile_result, string* header);
+
+// ParseCppClass parses `cpp_class` into its `class_name` and `namespaces`
+// components. The syntax is [[<optional_namespace>::],...]<class_name>. This
+// mirrors the C++ syntax for referring to a class, where multiple namespaces
+// may precede the class name, separated by double-colons.
+Status ParseCppClass(const string& cpp_class, string* class_name,
+ std::vector<string>* namespaces);
+
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_CODEGEN_H_
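
A short usage sketch of the two functions declared above, written as a helper inside the tensorflow::tfcompile namespace; the qualified class name is illustrative, and the Config/CompileResult are assumed to be built elsewhere (see codegen_test.cc):

    Status EmitHeader(const Config& config, const CompileResult& compile_result,
                      string* header) {
      HeaderOpts opts;
      // "foo::bar::MyClass" is an illustrative cpp_class value.
      TF_RETURN_IF_ERROR(ParseCppClass("foo::bar::MyClass", &opts.class_name,
                                       &opts.namespaces));
      // opts.class_name == "MyClass", opts.namespaces == {"foo", "bar"}
      return GenerateHeader(opts, config, compile_result, header);
    }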
diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc
new file mode 100644
index 0000000000..e3f76f3666
--- /dev/null
+++ b/tensorflow/compiler/aot/codegen_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/codegen.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace {
+
+class ParseCppClassTest : public ::testing::Test {
+ protected:
+ void ExpectOK(const string& cpp_class, const string& want_class_name,
+ const std::vector<string>& want_namespaces) {
+ string class_name;
+ std::vector<string> namespaces;
+ TF_EXPECT_OK(ParseCppClass(cpp_class, &class_name, &namespaces));
+ EXPECT_EQ(class_name, want_class_name);
+ EXPECT_EQ(namespaces, want_namespaces);
+ }
+
+ void ExpectFail(const string& cpp_class) {
+ string class_name;
+ std::vector<string> namespaces;
+ EXPECT_NE(ParseCppClass(cpp_class, &class_name, &namespaces), Status::OK());
+ }
+};
+
+TEST_F(ParseCppClassTest, ParseOK) {
+ ExpectOK("MyClass", "MyClass", {});
+ ExpectOK("_MyClass", "_MyClass", {});
+ ExpectOK("a::MyClass", "MyClass", {"a"});
+ ExpectOK("a::foo::MyClass", "MyClass", {"a", "foo"});
+ ExpectOK("a::foo::b::MyClass", "MyClass", {"a", "foo", "b"});
+ ExpectOK("a::foo::b::bar::MyClass", "MyClass", {"a", "foo", "b", "bar"});
+ ExpectOK("foo::MyClass", "MyClass", {"foo"});
+ ExpectOK("_foo::MyClass", "MyClass", {"_foo"});
+ ExpectOK("_foo::_MyClass", "_MyClass", {"_foo"});
+ // Make sure we didn't skip a valid letter or digit
+ string ident;
+ for (char c = 'a'; c <= 'z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = 'A'; c <= 'Z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = '0'; c <= '9'; c++) {
+ ident.append(1, c);
+ }
+ ident += "_";
+ ExpectOK(ident, ident, {});
+ ExpectOK(ident + "::" + ident, ident, {ident});
+ ExpectOK(ident + "::" + ident + "::" + ident, ident, {ident, ident});
+}
+
+TEST_F(ParseCppClassTest, ParseFail) {
+ ExpectFail("");
+ ExpectFail("::");
+ ExpectFail("::MyClass"); // valid C++, but disallowed for simpler code.
+ ExpectFail("0");
+ ExpectFail("a.b");
+ ExpectFail("a:b");
+ ExpectFail("good::.bad");
+ ExpectFail("good:::bad");
+ ExpectFail("good:: bad");
+ ExpectFail("good::0bad");
+}
+
+TEST(GenerateHeader, Golden) {
+ HeaderOpts opts;
+ opts.class_name = "MyClass";
+ opts.namespaces = {"foo", "bar"};
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("feed0");
+ feed->set_name("myfeed");
+ feed = config.add_feed();
+ feed->mutable_id()->set_node_name("feed1");
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("fetch0");
+ fetch->set_name("myfetch");
+ CompileResult compile_result;
+ compile_result.aot.reset(
+ new xla::cpu::CpuAotCompilationResult({}, {1, -1, 2, -1, 3, 120}, 5));
+ compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
+ {
+ xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+ xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+ xla::ShapeUtil::MakeOpaqueShape(),
+ },
+ xla::ShapeUtil::MakeShape(xla::U32, {5, 6}));
+ compile_result.has_context_arg = true;
+ compile_result.entry_point = "entry_point";
+ compile_result.pointer_size = 8;
+ string header;
+ TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header));
+
+ // Compare against the golden file.
+ const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(),
+ "compiler/aot/codegen_test_h.golden");
+ // To update the golden file, flip update_golden to true and run the
+ // following:
+ // bazel test --test_strategy=local \
+ // third_party/tensorflow/compiler/aot:codegen_test
+ const bool update_golden = false;
+ if (update_golden) {
+ TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header));
+ }
+ string golden_data;
+ TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data));
+ EXPECT_EQ(header, golden_data);
+}
+
+} // namespace
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
new file mode 100644
index 0000000000..46d7c03006
--- /dev/null
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -0,0 +1,268 @@
+// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
+//
+// This header was generated via ahead-of-time compilation of a TensorFlow
+// graph. An object file corresponding to this header was also generated.
+// This header gives access to the functionality in that object file.
+//
+// clang-format off
+
+#ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
+#define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard)
+
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/compiler/aot/runtime.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace Eigen { class ThreadPoolDevice; }
+
+// (Implementation detail) Entry point to the function in the object file.
+extern "C" void entry_point(
+ void* result, xla::ExecutableRunOptions* run_options,
+ void** args, void** temps);
+
+namespace foo {
+namespace bar {
+
+// MyClass represents a computation previously specified in a
+// TensorFlow graph, now compiled into executable code. Usage example:
+//
+// MyClass computation;
+// // ...set args using computation.argN methods
+// CHECK(computation.Run());
+// // ...inspect results using computation.resultN methods
+//
+// The Run method invokes the actual computation, with inputs read from arg
+// buffers, and outputs written to result buffers. Each Run call may also use
+// a set of temporary buffers for the computation.
+//
+// By default each instance of this class manages its own arg, result and temp
+// buffers. The AllocMode constructor parameter may be used to modify the
+// buffer allocation strategy.
+//
+// Under the default allocation strategy, this class is thread-compatible:
+// o Calls to non-const methods require exclusive access to the object.
+// o Concurrent calls to const methods are OK, if those calls are made while
+// it is guaranteed that no thread may call a non-const method.
+//
+// The logical function signature is:
+// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6]
+//
+// Memory stats:
+// arg bytes total: 104
+// arg bytes aligned: 128
+// temp bytes total: 126
+// temp bytes aligned: 224
+class MyClass {
+ public:
+ // Number of input arguments for the compiled computation.
+ static constexpr size_t kNumArgs = 3;
+
+ // Byte size of each argument buffer. There are kNumArgs entries.
+ static const intptr_t* ArgSizes() {
+ static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1};
+ return kArgSizes;
+ }
+
+ // AllocMode controls the buffer allocation mode.
+ enum class AllocMode {
+ // Allocate all buffers - args, results and temps.
+ ARGS_RESULTS_AND_TEMPS,
+
+ // Only allocate result and temp buffers.
+ // Use set_argN_data to set argument buffers before Run is called.
+ RESULTS_AND_TEMPS_ONLY,
+ };
+
+ MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ ArgSizes(), kNumArgs, args_, false /* annotate_initialized */);
+ }
+ args_[kNumArgs-1] = &context_;
+ alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
+ TempSizes(), kNumTemps, temps_, true /* annotate_initialized */);
+ }
+
+ ~MyClass() {
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_);
+ tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
+ }
+
+ // Sets the thread pool to use during the Run call.
+ MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
+ run_options_.set_intra_op_thread_pool(pool);
+ context_.thread_pool = pool;
+ return *this;
+ }
+
+ // Runs the computation, with inputs read from arg buffers, and outputs
+ // written to result buffers. Returns true on success and false on failure.
+ bool Run() {
+ entry_point(temps_[kResultIndex], &run_options_, args_, temps_);
+ return !context_.error;
+ }
+
+ // Returns the error message from the previous failed Run call.
+ tensorflow::string error_msg() const { return context_.error_msg; }
+
+ // Arg methods for managing input buffers. Buffers are in row-major order.
+ // There is a set of methods for each positional argument, with the following
+ // general form:
+ //
+ // void set_argN_data(void* data)
+ // Sets the buffer of type T for positional argument N. May be called in
+  //     any AllocMode. Must be called before Run to have an effect. Must be
+ // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
+ // to set the argument buffers.
+ //
+ // T* argN_data()
+ // Returns the buffer of type T for positional argument N.
+ //
+ // T& argN(...dim indices...)
+ // Returns a reference to the value of type T for positional argument N,
+ // with dim indices specifying which value. No bounds checking is performed
+ // on dim indices.
+ //
+ // void** args()
+ // Returns an array of argument buffers, where args()[N] is the buffer for
+ // positional argument N.
+
+ void** args() { return args_; }
+ const void *const *args() const { return args_; }
+
+ void set_arg0_data(void* data) {
+ args_[0] = data;
+ }
+ float* arg0_data() {
+ return static_cast<float*>(args_[0]);
+ }
+ float& arg0(size_t dim0, size_t dim1) {
+ return (*static_cast<float(*)[1][2]>(
+ args_[0]))[dim0][dim1];
+ }
+ const float* arg0_data() const {
+ return static_cast<const float*>(args_[0]);
+ }
+ const float& arg0(size_t dim0, size_t dim1) const {
+ return (*static_cast<const float(*)[1][2]>(
+ args_[0]))[dim0][dim1];
+ }
+
+ void set_arg_myfeed_data(void* data) {
+ args_[0] = data;
+ }
+ float* arg_myfeed_data() {
+ return static_cast<float*>(args_[0]);
+ }
+ float& arg_myfeed(size_t dim0, size_t dim1) {
+ return (*static_cast<float(*)[1][2]>(
+ args_[0]))[dim0][dim1];
+ }
+ const float* arg_myfeed_data() const {
+ return static_cast<const float*>(args_[0]);
+ }
+ const float& arg_myfeed(size_t dim0, size_t dim1) const {
+ return (*static_cast<const float(*)[1][2]>(
+ args_[0]))[dim0][dim1];
+ }
+
+ void set_arg1_data(void* data) {
+ args_[1] = data;
+ }
+ tensorflow::int64* arg1_data() {
+ return static_cast<tensorflow::int64*>(args_[1]);
+ }
+ tensorflow::int64& arg1(size_t dim0, size_t dim1) {
+ return (*static_cast<tensorflow::int64(*)[3][4]>(
+ args_[1]))[dim0][dim1];
+ }
+ const tensorflow::int64* arg1_data() const {
+ return static_cast<const tensorflow::int64*>(args_[1]);
+ }
+ const tensorflow::int64& arg1(size_t dim0, size_t dim1) const {
+ return (*static_cast<const tensorflow::int64(*)[3][4]>(
+ args_[1]))[dim0][dim1];
+ }
+
+ // Result methods for managing output buffers. Buffers are in row-major order.
+ // Must only be called after a successful Run call. There is a set of methods
+ // for each positional result, with the following general form:
+ //
+ // T* resultN_data()
+ // Returns the buffer of type T for positional result N.
+ //
+ // T& resultN(...dim indices...)
+ // Returns a reference to the value of type T for positional result N,
+ // with dim indices specifying which value. No bounds checking is performed
+ // on dim indices.
+ //
+ // void** results()
+ // Returns an array of result buffers, where results()[N] is the buffer for
+ // positional result N.
+ //
+ // Unlike the arg methods, there is no set_resultN_data method. The result
+ // buffers are managed internally, and may change after each call to Run.
+
+ void** results() { return temps_ + kResultIndex; }
+ const void *const *results() const { return temps_ + kResultIndex; }
+
+ tensorflow::uint32* result0_data() {
+ return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
+ }
+ tensorflow::uint32& result0(size_t dim0, size_t dim1) {
+ return (*static_cast<tensorflow::uint32(*)[5][6]>(
+ temps_[kResultIndex]))[dim0][dim1];
+ }
+ const tensorflow::uint32* result0_data() const {
+ return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
+ }
+ const tensorflow::uint32& result0(size_t dim0, size_t dim1) const {
+ return (*static_cast<const tensorflow::uint32(*)[5][6]>(
+ temps_[kResultIndex]))[dim0][dim1];
+ }
+
+ tensorflow::uint32* result_myfetch_data() {
+ return static_cast<tensorflow::uint32*>(temps_[kResultIndex]);
+ }
+ tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) {
+ return (*static_cast<tensorflow::uint32(*)[5][6]>(
+ temps_[kResultIndex]))[dim0][dim1];
+ }
+ const tensorflow::uint32* result_myfetch_data() const {
+ return static_cast<const tensorflow::uint32*>(temps_[kResultIndex]);
+ }
+ const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const {
+ return (*static_cast<const tensorflow::uint32(*)[5][6]>(
+ temps_[kResultIndex]))[dim0][dim1];
+ }
+
+ private:
+ // Number of result and temporary buffers for the compiled computation.
+ static constexpr size_t kNumTemps = 6;
+ // The 0-based index of the result in the temporary buffers.
+ static constexpr size_t kResultIndex = 5;
+
+ // Byte size of each result / temporary buffer. There are kNumTemps entries.
+ static const intptr_t* TempSizes() {
+ static constexpr intptr_t kTempSizes[kNumTemps] = {1, -1, 2, -1, 3, 120};
+ return kTempSizes;
+ }
+
+ void* args_[kNumArgs];
+ void* temps_[kNumTemps];
+ void* alloc_args_ = nullptr;
+ void* alloc_temps_ = nullptr;
+ xla::ExecutableRunOptions run_options_;
+ tensorflow::XlaLocalRuntimeContext context_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MyClass);
+};
+
+} // end namespace bar
+} // end namespace foo
+
+#endif // TFCOMPILE_GENERATED_entry_point_H_
+
+// clang-format on
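
A hypothetical caller of the generated class above might look like the following sketch; the header name, argument values, and thread-pool setup are illustrative only:

    #define EIGEN_USE_THREADS
    #define EIGEN_USE_CUSTOM_THREAD_POOL

    #include <stdio.h>

    #include "my_class.h"  // hypothetical name for the generated header above
    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

    int main() {
      Eigen::ThreadPool pool(1);
      Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());

      foo::bar::MyClass computation;  // allocates arg, result and temp buffers
      computation.set_thread_pool(&device);

      // Illustrative values: fill the f32[1,2] and s64[3,4] args; the opaque
      // context arg is set up internally by the constructor.
      computation.arg_myfeed(0, 0) = 1.0f;
      computation.arg_myfeed(0, 1) = 2.0f;
      for (int i = 0; i < 3; ++i) {
        for (int j = 0; j < 4; ++j) computation.arg1(i, j) = i * 4 + j;
      }

      if (!computation.Run()) {
        printf("Run failed: %s\n", computation.error_msg().c_str());
        return 1;
      }
      printf("result_myfetch(0, 0) = %u\n", computation.result_myfetch(0, 0));
      return 0;
    }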
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
new file mode 100644
index 0000000000..50e596786a
--- /dev/null
+++ b/tensorflow/compiler/aot/compile.cc
@@ -0,0 +1,416 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/compile.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+const char* const kArgOp = "_Arg";
+const char* const kRetvalOp = "_Retval";
+const char* const kFeedIdAttr = "_feed_id";
+const char* const kFetchIdAttr = "_fetch_id";
+const char* const kShapeAttr = "_shape";
+const char* const kDebugNameAttr = "_debug_name";
+
+namespace {
+
+Status DumpGraph(const MainFlags& flags, const string& name,
+ const Graph& graph) {
+ if (flags.debug_dir.empty()) {
+ return Status::OK();
+ }
+ GraphDef graph_def;
+ graph.ToGraphDef(&graph_def);
+ string file = io::JoinPath(flags.debug_dir, name + ".pbtxt");
+ return WriteTextProto(Env::Default(), file, graph_def);
+}
+
+string TensorIdToString(const TensorId& id) {
+ return strings::StrCat(id.node_name(), ":", id.output_index());
+}
+
+typedef std::unordered_map<string, Node*> NodeMap;
+
+// Each feed id identifies the positional output of some node, which may be
+// consumed by multiple out-edges. For each feed node, replaces all matching
+// edges so that they originate from a new _Arg node instead.
+Status AddArgNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<Feed>& feeds) {
+ for (int arg_index = 0; arg_index < feeds.size(); ++arg_index) {
+ const Feed& feed = feeds[arg_index];
+ const TensorId& id = feed.id();
+ auto it = node_map.find(id.node_name());
+ if (it == node_map.end()) {
+ return errors::NotFound("Can't find feed id: ", TensorIdToString(id));
+ }
+ const Node* feed_node = it->second;
+ if (id.output_index() >= feed_node->num_outputs()) {
+ return errors::InvalidArgument("Invalid feed id: ", TensorIdToString(id),
+ ", output index should be < ",
+ feed_node->num_outputs());
+ }
+ // TODO(toddw): Invoke shape inference on the graph and add a "_shape" attr
+ // if we can determine it. That way the graph will be initialized with
+ // whatever shapes we can infer, while the user can still explicitly specify
+ // or override them.
+ Node* arg_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_arg_", arg_index), kArgOp)
+ .Attr("T", BaseType(feed_node->output_type(id.output_index())))
+ .Attr("index", arg_index)
+ .Attr(kFeedIdAttr, TensorIdToString(id))
+ .Attr(kShapeAttr, TensorShape(feed.shape()))
+ .Attr(kDebugNameAttr, feed.name())
+ .Finalize(graph, &arg_node));
+ // Collects out-edges from the feed node that have a matching edge index;
+ // these will be replaced with edges from the arg node instead. Also
+ // replaces all control edges from Placeholder feed nodes; similar code
+ // exists in subgraph::RewriteGraphForExecution.
+ // TODO(toddw): Why only replace control edges from Placeholder?
+ //
+ // We must collect the edges first and process them in a second pass, since
+ // removing the edge from the graph invalidates feed_node->out_edges.
+ std::vector<const Edge*> feed_edges;
+ for (const Edge* edge : feed_node->out_edges()) {
+ if (edge->src_output() == id.output_index() ||
+ (edge->src_output() == Graph::kControlSlot &&
+ feed_node->type_string() == "Placeholder")) {
+ feed_edges.push_back(edge);
+ }
+ }
+ for (const Edge* edge : feed_edges) {
+ if (edge->src_output() == id.output_index()) {
+ graph->AddEdge(arg_node, 0, edge->dst(), edge->dst_input());
+ } else {
+ CHECK_EQ(edge->src_output(), Graph::kControlSlot);
+ graph->AddControlEdge(arg_node, edge->dst());
+ }
+ graph->RemoveEdge(edge);
+ }
+ }
+ return Status::OK();
+}
+
+// Each fetch id identifies the positional output of some node. For each fetch
+// node, adds a new _Retval node fed by that output, and adds the new node to
+// `retval_nodes`.
+Status AddRetvalNodes(Graph* graph, const NodeMap& node_map,
+ const protobuf::RepeatedPtrField<Fetch>& fetches,
+ std::unordered_set<const Node*>* retval_nodes) {
+ for (int ret_index = 0; ret_index < fetches.size(); ++ret_index) {
+ const TensorId& id = fetches[ret_index].id();
+ auto it = node_map.find(id.node_name());
+ if (it == node_map.end()) {
+ return errors::NotFound("Can't find fetch id: ", TensorIdToString(id));
+ }
+ Node* fetch_node = it->second;
+ if (id.output_index() >= fetch_node->num_outputs()) {
+ return errors::InvalidArgument("Invalid fetch id: ", TensorIdToString(id),
+ ", output index should be < ",
+ fetch_node->num_outputs());
+ }
+ // Connects fetch_node -> retval_node.
+ Node* retval_node = nullptr;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(strings::StrCat("_retval_", ret_index), kRetvalOp)
+ .Input(fetch_node, id.output_index())
+ .Attr("T", BaseType(fetch_node->output_type(id.output_index())))
+ .Attr("index", ret_index)
+ .Attr(kFetchIdAttr, TensorIdToString(id))
+ .Finalize(graph, &retval_node));
+ retval_nodes->insert(retval_node);
+ }
+ return Status::OK();
+}
+
+// RewriteAndPruneGraph identifies input and output edges (named by the feed and
+// fetch ids respectively), and rewrites the edges so that inputs flow from _Arg
+// nodes, and outputs flow to _Retval nodes. This allows the symbolic graph
+// execution to know the input and output args for the generated function.
+Status RewriteAndPruneGraph(Graph* graph, const Config& config,
+ const MainFlags& flags) {
+ NodeMap node_map;
+ for (Node* n : graph->nodes()) {
+ node_map[n->name()] = n;
+ }
+ TF_RETURN_IF_ERROR(AddArgNodes(graph, node_map, config.feed()));
+ std::unordered_set<const Node*> retval_nodes;
+ TF_RETURN_IF_ERROR(
+ AddRetvalNodes(graph, node_map, config.fetch(), &retval_nodes));
+ TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_rewrite", *graph));
+ PruneForReverseReachability(graph, retval_nodes);
+ FixupSourceAndSinkEdges(graph);
+ TF_RETURN_IF_ERROR(DumpGraph(flags, "tfcompile_post_prune", *graph));
+ // Sanity-check, to make sure the feeds and fetches still exist post-pruning.
+ std::set<string> missing_feeds, missing_fetches;
+ for (const Feed& feed : config.feed()) {
+ missing_feeds.insert(TensorIdToString(feed.id()));
+ }
+ for (const Fetch& fetch : config.fetch()) {
+ missing_fetches.insert(TensorIdToString(fetch.id()));
+ }
+ for (const Node* n : graph->nodes()) {
+ if (n->type_string() == kArgOp) {
+ string feed_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFeedIdAttr, &feed_id));
+ if (missing_feeds.erase(feed_id) == 0) {
+ return errors::Aborted(kArgOp, " node found with unknown feed id: ",
+ feed_id);
+ }
+ } else if (n->type_string() == kRetvalOp) {
+ string fetch_id;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), kFetchIdAttr, &fetch_id));
+ if (missing_fetches.erase(fetch_id) == 0) {
+ return errors::Aborted(kRetvalOp, " node found with unknown fetch id: ",
+ fetch_id);
+ }
+ }
+ }
+ if (!missing_feeds.empty() || !missing_fetches.empty()) {
+ return errors::Aborted("Post graph-pruning", ", missing feeds: ",
+ str_util::Join(missing_feeds, ", "),
+ ", missing fetches: ",
+ str_util::Join(missing_fetches, ", "));
+ }
+ return Status::OK();
+}
+
+// CollectArgNodes collects _Arg nodes from the graph, and performs basic
+// sanity-checking to ensure the index attribute of each node is set correctly,
+// with no duplicate or missing indices.
+Status CollectArgNodes(const Graph& graph, std::vector<Node*>* arg_nodes) {
+ std::map<int, Node*> indexed_arg_nodes;
+ for (Node* n : graph.nodes()) {
+ if (n->type_string() == kArgOp) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
+ auto insert_result = indexed_arg_nodes.insert({index, n});
+ if (!insert_result.second) {
+ const Node* dup = insert_result.first->second;
+ return errors::InvalidArgument(
+ "Multiple ", kArgOp, " nodes with index ", index, ", ",
+ n->DebugString(), " and ", dup->DebugString());
+ }
+ }
+ }
+ arg_nodes->clear();
+ for (const auto& index_node : indexed_arg_nodes) {
+ if (index_node.first != arg_nodes->size()) {
+ return errors::InvalidArgument("Expected ", kArgOp, " node with index ",
+ arg_nodes->size(), ", but got index ",
+ index_node.first);
+ }
+ arg_nodes->push_back(index_node.second);
+ }
+ return Status::OK();
+}
+
+// Fills in xla_args from the corresponding _Arg nodes in the graph.
+Status CreateXlaArgs(const Graph& graph,
+ std::vector<XlaCompiler::Argument>* xla_args) {
+ std::vector<Node*> arg_nodes;
+ TF_RETURN_IF_ERROR(CollectArgNodes(graph, &arg_nodes));
+ for (const Node* node : arg_nodes) {
+ XlaCompiler::Argument arg;
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "T", &arg.type));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), "index", &arg.parameter));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kShapeAttr, &arg.shape));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->def(), kDebugNameAttr, &arg.name));
+ xla_args->push_back(arg);
+ }
+ return Status::OK();
+}
+
+// Converts the TensorFlow graph into an XLA computation, by executing the
+// graph symbolically, with each op building up the XLA HLO.
+Status ConvertGraphToXla(xla::LocalClient* client, std::unique_ptr<Graph> graph,
+ const FunctionLibraryDefinition* flib_def,
+ xla::Computation* computation, bool* has_context_arg) {
+ // Create a device and context to convert the graph into an XLA computation.
+ XlaOpRegistry::RegisterJitKernels();
+  // Assign every node to the XLA CPU JIT device.
+ for (Node* node : graph->nodes()) {
+ node->set_assigned_device_name(DEVICE_CPU_XLA_JIT);
+ }
+ std::vector<XlaCompiler::Argument> xla_args;
+ TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args));
+
+ // Compile the graph into an XLA computation.
+ XlaCompiler::Options compiler_options;
+ compiler_options.client = client;
+ compiler_options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
+ compiler_options.allow_cpu_custom_calls = true;
+ XlaCompiler compiler(compiler_options);
+
+ std::unique_ptr<FunctionLibraryRuntime> flib_run(NewFunctionLibraryRuntime(
+ compiler.device_mgr(), Env::Default(), compiler.device(),
+ graph->versions().producer(), flib_def, OptimizerOptions()));
+ XlaCompiler::CompilationResult result;
+ TF_RETURN_IF_ERROR(compiler.CompileGraph("tfcompile", std::move(graph),
+ flib_run.get(), xla_args,
+ false /* use_tuple_arg */, &result));
+ *has_context_arg = result.requires_runtime_context;
+ *computation = std::move(result.computation);
+
+ int num_const_results = 0;
+ for (int i = 0; i < result.outputs.size(); ++i) {
+ // Ending up with const results (i.e. output args) is an error, since it
+ // means that one or more fetches that the user specified will be dropped
+ // from the generated function. It's most likely a configuration error,
+ // since the user shouldn't be asking for output args that end up as consts.
+ //
+ // TODO(toddw): Provide a way for the user to access const output args,
+ // e.g. perhaps hard-coded into the header, or somehow copied into the
+ // output buffers.
+ if (result.outputs[i].is_constant) {
+ ++num_const_results;
+ LOG(ERROR) << "ConstRetVal index:" << i
+ << " value:" << result.outputs[i].constant_value.DebugString();
+ }
+ }
+ if (num_const_results > 0) {
+ return errors::Unimplemented(
+ "Conversion from TensorFlow graph to XLA resulted in ",
+ num_const_results,
+ " constant results. The configuration of "
+ "the output args (i.e. fetch ids) is probably wrong.");
+ }
+ if (computation->IsNull()) {
+ return errors::Aborted(
+ "Conversion from TensorFlow graph to XLA resulted in an empty "
+ "computation.");
+ }
+ return Status::OK();
+}
+
+// Compiles the XLA computation into executable code.
+Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
+ const xla::cpu::CpuAotCompilationOptions& aot_opts,
+ CompileResult* compile_result) {
+ // Retrieves arg and result layouts from the computation.
+ // TODO(toddw): Should we let the user choose the major/minor ordering?
+ xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
+ client->GetComputationShape(computation);
+ if (!pshape_or.ok()) {
+ return errors::Unknown("Couldn't get XLA program shape: ",
+ pshape_or.status().error_message());
+ }
+ compile_result->program_shape = *pshape_or.ValueOrDie();
+ xla::ProgramShape* pshape = &compile_result->program_shape;
+ std::vector<const xla::Shape*> arg_layouts;
+ for (int i = 0; i < pshape->parameters_size(); ++i) {
+ arg_layouts.push_back(pshape->mutable_parameters(i));
+ }
+ xla::StatusOr<std::unique_ptr<xla::AotCompilationResult>> aot_or =
+ client->CompileAheadOfTime(computation, arg_layouts, pshape->result(),
+ aot_opts);
+ if (!aot_or.ok()) {
+ return errors::Unknown("XLA compilation failed: ",
+ aot_or.status().error_message());
+ }
+ compile_result->aot =
+ xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
+ aot_or.ConsumeValueOrDie());
+ compile_result->entry_point = aot_opts.entry_point_name();
+ compile_result->pointer_size =
+ xla::LocalClient::PointerSizeForTriple(aot_opts.triple());
+ return Status::OK();
+}
+
+} // namespace
+
+Status InitGraph(const GraphDef& graph_def, const Config& config,
+ const MainFlags& flags, const FunctionLibraryDefinition* flib,
+ std::unique_ptr<Graph>* graph) {
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
+ std::unique_ptr<Graph> g(new Graph(flib));
+ GraphDef copy_def(graph_def);
+ TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&copy_def, *g->op_registry(),
+ 0 /*node_offset*/));
+ TF_RETURN_IF_ERROR(
+ ConvertGraphDefToGraph(GraphConstructorOptions(), copy_def, g.get()));
+ TF_RETURN_IF_ERROR(RewriteAndPruneGraph(g.get(), config, flags));
+ *graph = std::move(g);
+ return Status::OK();
+}
+
+Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
+ const FunctionLibraryDefinition* flib,
+ CompileResult* compile_result) {
+ // Converts the graph into an XLA computation, and compiles the
+ // computation.
+ // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
+ namespace gpu = perftools::gputools;
+ gpu::Platform* cpu_platform =
+ gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
+ xla::LocalClient* client =
+ xla::ClientLibrary::GetOrCreateLocalClient(cpu_platform).ValueOrDie();
+ xla::Computation computation;
+ TF_RETURN_IF_ERROR(ConvertGraphToXla(client, std::move(graph), flib,
+ &computation,
+ &compile_result->has_context_arg));
+ if (!flags.debug_dir.empty()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::SessionModule> module,
+ computation.Snapshot());
+ string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb");
+ TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module));
+ }
+ xla::cpu::CpuAotCompilationOptions aot_opts(
+ flags.target_triple, flags.target_cpu, flags.target_features,
+ flags.entry_point,
+ xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);
+ return CompileXla(client, computation, aot_opts, compile_result);
+}
+
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h
new file mode 100644
index 0000000000..8e9c64820b
--- /dev/null
+++ b/tensorflow/compiler/aot/compile.h
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_AOT_COMPILE_H_
+#define TENSORFLOW_COMPILER_AOT_COMPILE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+// Constants for op types and attribute names.
+extern const char* const kArgOp;
+extern const char* const kRetvalOp;
+extern const char* const kFeedIdAttr;
+extern const char* const kFetchIdAttr;
+extern const char* const kShapeAttr;
+extern const char* const kDebugNameAttr;
+
+// InitGraph creates a graph based on the graph_def, that may then be compiled
+// by CompileGraph.
+//
+// The graph is rewritten with _Arg and _Retval nodes, representing the inputs
+// and outputs of the function that will be compiled. Each feed id causes a new
+// _Arg node to be created, where we first collect all existing edges pointing
+// from the named node's output index, and then rewrite them to point from that
+// _Arg node instead. Each fetch id causes a new _Retval node to be created,
+// with a new edge pointing from the named node's output index to that _Retval
+// node. All _Retval nodes also point to a special CompileExpressions node,
+// used internally to finish the compilation.
+//
+// The rewritten graph is then pruned to only contain the portion necessary to
+// compute the outputs. If dump_graphs is true, graph rewrites will be dumped
+// for debugging.
+Status InitGraph(const GraphDef& graph_def, const Config& config,
+ const MainFlags& flags, const FunctionLibraryDefinition* flib,
+ std::unique_ptr<Graph>* graph);
+
+// CompileResult describes the output of CompileGraph, where the object file
+// data and meta-information is available in aot.
+struct CompileResult {
+ // Contains object file and meta-info.
+ std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
+ xla::ProgramShape program_shape; // Static shape of args and results.
+ bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext?
+ string entry_point; // Name of generated function.
+ int pointer_size = 0; // Size of a pointer in bytes.
+};
+
+// CompileGraph compiles the graph into an object file containing a function
+// that performs the graph operations.
+//
+// The graph must have _Arg and _Retval nodes representing the function inputs
+// and outputs. Every _Arg node must have a shape attribute (key=kShapeAttr,
+// value=TensorShape) representing the static shape of that input, and every
+// _Retval node must point to a CompileExpressions node.
+//
+// Typically InitGraph is called to perform this initialization, followed by
+// full specification of the shape attributes.
+//
+// The XLA compilation options are specified in the flags.
+Status CompileGraph(std::unique_ptr<Graph> graph, const MainFlags& flags,
+ const FunctionLibraryDefinition* flib,
+ CompileResult* result);
+
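+// Example usage (a rough sketch; proto parsing and flag setup are elided, and
+// the placeholder values are illustrative only):
+//
+//   GraphDef graph_def = ...;   // parsed from the --graph file
+//   Config config = ...;        // parsed from the --config file
+//   MainFlags flags = ...;      // populated from the command line
+//   FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
+//   std::unique_ptr<Graph> graph;
+//   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
+//   CompileResult result;
+//   TF_RETURN_IF_ERROR(CompileGraph(std::move(graph), flags, &flib, &result));
+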
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_COMPILE_H_
diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc
new file mode 100644
index 0000000000..4e3998b682
--- /dev/null
+++ b/tensorflow/compiler/aot/flags.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/flags.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags) {
+ const std::vector<Flag> tmp = {
+ {"graph", &flags->graph,
+ "Input GraphDef file. If the file ends in '.pbtxt' it is expected to "
+ "be in the human-readable proto text format, otherwise it is expected "
+ "to be in the proto binary format."},
+ {"config", &flags->config,
+ "Input file containing Config proto. If the file ends in '.pbtxt' it "
+ "is expected to be in the human-readable proto text format, otherwise "
+ "it is expected to be in the proto binary format."},
+ {"dump_fetch_nodes", &flags->dump_fetch_nodes,
+ "If set, only flags related to fetches are processed, and the resulting "
+ "fetch nodes will be dumped to stdout in a comma-separated list. "
+ "Typically used to format arguments for other tools, e.g. "
+ "freeze_graph."},
+ {"debug_dir", &flags->debug_dir,
+ "Specifies a directory to dump debugging information, including "
+ "rewritten graphs and the XLA HLO module."},
+ // Flags controlling the XLA ahead-of-time compilation, that correspond to
+ // the fields of xla::cpu::CpuAotCompilationOptions.
+ //
+ // TODO(toddw): The following flags also need to be supported:
+ // --xla_cpu_llvm_opt_level
+ // --xla_cpu_llvm_cl_opts
+ {"target_triple", &flags->target_triple,
+ "Target platform, similar to the clang -target flag. The general "
+ "format is <arch><sub>-<vendor>-<sys>-<abi>. "
+ "http://clang.llvm.org/docs/CrossCompilation.html#target-triple."},
+ {"target_cpu", &flags->target_cpu,
+ "Target cpu, similar to the clang -mcpu flag. "
+ "http://clang.llvm.org/docs/CrossCompilation.html#cpu-fpu-abi"},
+ {"target_features", &flags->target_features,
+ "Target features, e.g. +avx2, +neon, etc."},
+ {"entry_point", &flags->entry_point,
+ "Name of the generated function. If multiple generated object files "
+ "will be linked into the same binary, each will need a unique entry "
+ "point."},
+ {"cpp_class", &flags->cpp_class,
+ "Name of the generated C++ class, wrapping the generated function. The "
+ "syntax of this flag is [[<optional_namespace>::],...]<class_name>. "
+ "This mirrors the C++ syntax for referring to a class, where multiple "
+ "namespaces may precede the class name, separated by double-colons. "
+ "The class will be generated in the given namespace(s), or if no "
+ "namespaces are given, within the global namespace."},
+ {"out_object", &flags->out_object, "Output object file name."},
+ {"out_header", &flags->out_header, "Output header file name."},
+ };
+ flag_list->insert(flag_list->end(), tmp.begin(), tmp.end());
+}
+
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h
new file mode 100644
index 0000000000..e11a0173fa
--- /dev/null
+++ b/tensorflow/compiler/aot/flags.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_AOT_FLAGS_H_
+#define TENSORFLOW_COMPILER_AOT_FLAGS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+// Flags for the tfcompile binary. See *.cc file for descriptions.
+struct MainFlags {
+ string graph;
+ string config;
+ bool dump_fetch_nodes = false;
+ string debug_dir;
+ string target_triple;
+ string target_cpu;
+ string target_features;
+ string entry_point;
+ string cpp_class;
+ string out_object;
+ string out_header;
+};
+
+// Appends to flag_list a tensorflow::Flag for each field in MainFlags.
+void AppendMainFlags(std::vector<Flag>* flag_list, MainFlags* flags);
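+
+// Example usage (a brief sketch; see tfcompile_main.cc for the real flag
+// handling):
+//
+//   MainFlags flags;
+//   std::vector<Flag> flag_list;
+//   AppendMainFlags(&flag_list, &flags);
+//   const bool parsed_ok = Flags::Parse(&argc, argv, flag_list);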
+
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_FLAGS_H_
diff --git a/tensorflow/compiler/aot/runtime.cc b/tensorflow/compiler/aot/runtime.cc
new file mode 100644
index 0000000000..208de5498d
--- /dev/null
+++ b/tensorflow/compiler/aot/runtime.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/runtime.h"
+
+#include <stdlib.h>
+
+#include "tensorflow/core/platform/dynamic_annotations.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace runtime {
+
+namespace {
+
+// Inline memory allocation routines here, because depending on '//base' brings
+// in libraries which use c++ streams, which adds considerable code size on
+// android.
+inline void* aligned_malloc(size_t size, int minimum_alignment) {
+#if defined(__ANDROID__) || defined(OS_ANDROID) || defined(OS_CYGWIN)
+ return memalign(minimum_alignment, size);
+#else // !__ANDROID__ && !OS_ANDROID && !OS_CYGWIN
+ void* ptr = nullptr;
+  // posix_memalign requires that the requested alignment be at least
+  // sizeof(void*). If the caller asks for less than that, fall back on malloc,
+  // which already returns memory aligned to at least the size of a pointer.
+ const int required_alignment = sizeof(void*);
+ if (minimum_alignment < required_alignment) return malloc(size);
+  if (posix_memalign(&ptr, minimum_alignment, size) != 0) {
+    return nullptr;
+  }
+  return ptr;
+#endif
+}
+
+inline void aligned_free(void* aligned_memory) { free(aligned_memory); }
+
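+// Rounds n up to the nearest multiple of align. Assumes n > 0 and align > 0.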
+size_t align_to(size_t n, size_t align) {
+ return (((n - 1) / align) + 1) * align;
+}
+
+} // namespace
+
+size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n) {
+ size_t total = 0;
+ for (size_t i = 0; i < n; ++i) {
+ if (sizes[i] != -1) {
+ total += align_to(sizes[i], kAlign);
+ }
+ }
+ return total;
+}
+
+void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
+ bool annotate_initialized) {
+ const size_t total = aligned_buffer_bytes(sizes, n);
+ void* contiguous = nullptr;
+ if (total > 0) {
+ contiguous = aligned_malloc(total, kAlign);
+ if (annotate_initialized) {
+ // Since the memory for temp buffers is written to by JITed code, msan has
+ // no way of knowing the memory was initialized, so explicitly mark it.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(contiguous, total);
+ }
+ }
+ uintptr_t pos = reinterpret_cast<uintptr_t>(contiguous);
+ for (size_t i = 0; i < n; ++i) {
+ if (sizes[i] == -1) {
+ bufs[i] = nullptr;
+ } else {
+ bufs[i] = reinterpret_cast<void*>(pos);
+ pos += align_to(sizes[i], kAlign);
+ }
+ }
+ return contiguous;
+}
+
+void FreeContiguous(void* contiguous) {
+ if (contiguous != nullptr) {
+ aligned_free(contiguous);
+ }
+}
+
+} // namespace runtime
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/runtime.h b/tensorflow/compiler/aot/runtime.h
new file mode 100644
index 0000000000..d085864f00
--- /dev/null
+++ b/tensorflow/compiler/aot/runtime.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file contains utilities to make it easier to invoke functions generated
+// by tfcompile. Usage of these utilities is optional.
+
+#ifndef TENSORFLOW_COMPILER_AOT_RUNTIME_H_
+#define TENSORFLOW_COMPILER_AOT_RUNTIME_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace runtime {
+
+// Align to 32 bytes, to mimic tensorflow::Allocator::kAllocatorAlignment.
+static constexpr size_t kAlign = 32;
+
+// aligned_buffer_bytes returns the total number of bytes needed to hold the
+// buffers described by `sizes`, with each size rounded up to a multiple of
+// kAlign; entries of -1 are skipped. There are `n` entries in `sizes`.
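+// For example, sizes of {1, -1, 32} (n = 3) yield 32 + 0 + 32 = 64 bytes.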
+size_t aligned_buffer_bytes(const intptr_t* sizes, size_t n);
+
+// MallocContiguousBuffers allocates buffers for use by the entry point
+// generated by tfcompile. `sizes` is an array of byte sizes for each buffer,
+// where -1 causes the buffer pointer to be nullptr. There are `n` entries in
+// `sizes`. If `annotate_initialized` is set, the allocated memory will be
+// annotated as having been initialized - this is useful when allocating
+// temporary buffers.
+//
+// A single contiguous block of memory is allocated, and portions of it are
+// parceled out into `bufs`, which must have space for `n` entries. Returns the
+// head of the allocated contiguous block, which should be passed to
+// FreeContiguous when the buffers are no longer in use.
+void* MallocContiguousBuffers(const intptr_t* sizes, size_t n, void** bufs,
+ bool annotate_initialized);
+
+// FreeContiguous frees the contiguous block of memory allocated by
+// MallocContiguousBuffers.
+void FreeContiguous(void* contiguous);
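+
+// Example usage (a minimal sketch; the sizes below are illustrative only):
+//
+//   constexpr intptr_t sizes[3] = {128, -1, 64};
+//   void* bufs[3];
+//   void* block =
+//       MallocContiguousBuffers(sizes, 3, bufs, /*annotate_initialized=*/false);
+//   // bufs[0] and bufs[2] now point into block; bufs[1] is nullptr.
+//   FreeContiguous(block);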
+
+} // namespace runtime
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_RUNTIME_H_
diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc
new file mode 100644
index 0000000000..ac79c278c1
--- /dev/null
+++ b/tensorflow/compiler/aot/runtime_test.cc
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/runtime.h"
+
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace runtime {
+namespace {
+
+TEST(Runtime, AlignmentValue) {
+  // We've chosen 32-byte alignment for the tfcompile runtime to mimic the
+  // regular tensorflow allocator, which was chosen to play nicely with Eigen.
+  // The tfcompile runtime also has an alignment requirement imposed by the
+  // XLA-generated code: (buffer_size >= 16 ? 2 * sizeof(void*) : 8) bytes.
+  // Any value we choose must satisfy that constraint as well.
+ EXPECT_EQ(kAlign, Allocator::kAllocatorAlignment);
+}
+
+TEST(Runtime, AlignedBufferBytes) {
+ EXPECT_EQ(aligned_buffer_bytes(nullptr, 0), 0);
+
+ static constexpr intptr_t sizesA[1] = {-1};
+ EXPECT_EQ(aligned_buffer_bytes(sizesA, 1), 0);
+
+ static constexpr intptr_t sizesB[1] = {3};
+ EXPECT_EQ(aligned_buffer_bytes(sizesB, 1), 32);
+
+ static constexpr intptr_t sizesC[1] = {32};
+ EXPECT_EQ(aligned_buffer_bytes(sizesC, 1), 32);
+
+ static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
+ EXPECT_EQ(aligned_buffer_bytes(sizesD, 7), 192);
+}
+
+void* add_ptr(void* base, uintptr_t delta) {
+ return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(base) + delta);
+}
+
+// To test MallocContiguousBuffers and FreeContiguous, we just check for
+// expected nullptrs, and write to each byte of allocated memory. We rely on
+// the leak checker to tell us if there's an inconsistency between malloc and
+// free. We also check the contiguous property.
+TEST(Runtime, MallocFreeContiguousBuffers) {
+ // Test empty sizes.
+ void* base = MallocContiguousBuffers(nullptr, 0, nullptr, false);
+ EXPECT_EQ(base, nullptr);
+ FreeContiguous(base);
+
+ // Test non-empty sizes with 0 sum.
+ static constexpr intptr_t sizesA[1] = {-1};
+ void* bufA[1];
+ base = MallocContiguousBuffers(sizesA, 1, bufA, false);
+ EXPECT_EQ(base, nullptr);
+ EXPECT_EQ(bufA[0], nullptr);
+ FreeContiguous(base);
+
+ // Test non-empty sizes with non-0 sum.
+ static constexpr intptr_t sizesB[1] = {3};
+ void* bufB[1];
+ base = MallocContiguousBuffers(sizesB, 1, bufB, false);
+ EXPECT_NE(base, nullptr);
+ EXPECT_EQ(bufB[0], add_ptr(base, 0));
+ char* bufB0_bytes = static_cast<char*>(bufB[0]);
+ bufB0_bytes[0] = 'A';
+ bufB0_bytes[1] = 'B';
+ bufB0_bytes[2] = 'C';
+ FreeContiguous(base);
+
+ // Test non-empty sizes with non-0 sum, and annotate_initialized.
+ static constexpr intptr_t sizesC[1] = {3};
+ void* bufC[1];
+ base = MallocContiguousBuffers(sizesC, 1, bufC, true);
+ EXPECT_NE(base, nullptr);
+ EXPECT_EQ(bufC[0], add_ptr(base, 0));
+ char* bufC0_bytes = static_cast<char*>(bufC[0]);
+ bufC0_bytes[0] = 'A';
+ bufC0_bytes[1] = 'B';
+ bufC0_bytes[2] = 'C';
+ FreeContiguous(base);
+
+ // Test mixed sizes.
+ static constexpr intptr_t sizesD[7] = {1, -1, 32, -1, 64, 2, 3};
+ void* bufD[7];
+ base = MallocContiguousBuffers(sizesD, 7, bufD, false);
+ EXPECT_NE(base, nullptr);
+ EXPECT_EQ(bufD[0], add_ptr(base, 0));
+ EXPECT_EQ(bufD[1], nullptr);
+ EXPECT_EQ(bufD[2], add_ptr(base, 32));
+ EXPECT_EQ(bufD[3], nullptr);
+ EXPECT_EQ(bufD[4], add_ptr(base, 64));
+ EXPECT_EQ(bufD[5], add_ptr(base, 128));
+ EXPECT_EQ(bufD[6], add_ptr(base, 160));
+ for (int i = 0; i < 7; ++i) {
+ const intptr_t size = sizesD[i];
+ if (size != -1) {
+ char* bufD_bytes = static_cast<char*>(bufD[i]);
+      for (intptr_t j = 0; j < size; ++j) {
+ bufD_bytes[j] = 'A' + j;
+ }
+ }
+ }
+ FreeContiguous(base);
+}
+
+} // namespace
+} // namespace runtime
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
new file mode 100644
index 0000000000..47ef5f82cb
--- /dev/null
+++ b/tensorflow/compiler/aot/test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generated by the tf_library build rule. DO NOT EDIT!
+//
+// This file contains a test and benchmark for the function generated by
+// tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to
+// real values before this file can be compiled.
+//
+// TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
+// TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
+// TFCOMPILE_NAME : Name for tests and benchmarks.
+//
+// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
+// generates a cc_test rule for you.
+
+// These macros must be defined before eigen files are included.
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+
+// clang-format off
+#include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
+// clang-format on
+
+#include <string.h>  // for memset, used by zero_buffers below
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+// Macros that expand to tokens based on the entry point name.
+// clang-format off
+#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
+#define TEST_NAME {{TFCOMPILE_NAME}}Test // NOLINT(whitespace/braces)
+#define BM_NAME BM_{{TFCOMPILE_NAME}} // NOLINT(whitespace/braces)
+// clang-format on
+
+namespace tensorflow {
+namespace tfcompile {
+namespace {
+
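+// Zeroes every allocated arg buffer; entries whose size is -1 have no backing
+// buffer and are skipped.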
+void zero_buffers(void** bufs, const intptr_t* sizes, size_t n) {
+  for (size_t i = 0; i < n; ++i) {
+ if (sizes[i] != -1) {
+ memset(bufs[i], 0, sizes[i]);
+ }
+ }
+}
+
+// Trivial test that runs the generated function to ensure it doesn't crash.
+TEST(TEST_NAME, NoCrash) {
+ Eigen::ThreadPool pool(port::NumSchedulableCPUs());
+ Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
+
+ CPP_CLASS computation;
+ computation.set_thread_pool(&device);
+ zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+
+ EXPECT_TRUE(computation.Run());
+}
+
+// Simple benchmark that repeatedly runs the generated function.
+void BM_NAME(int iters) {
+ testing::StopTiming();
+
+ Eigen::ThreadPool pool(port::NumSchedulableCPUs());
+ Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
+
+ CPP_CLASS computation;
+ computation.set_thread_pool(&device);
+ zero_buffers(computation.args(), CPP_CLASS::ArgSizes(), CPP_CLASS::kNumArgs);
+
+ testing::StartTiming();
+  while (iters-- > 0) {
+ computation.Run();
+ }
+ testing::StopTiming();
+}
+BENCHMARK(BM_NAME);
+
+} // namespace
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
new file mode 100644
index 0000000000..5625c0ab03
--- /dev/null
+++ b/tensorflow/compiler/aot/test_graph_tfadd.config.pbtxt
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_sum" }
+}
diff --git a/tensorflow/compiler/aot/test_graph_tfadd.pbtxt b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt
new file mode 100644
index 0000000000..91c900e06d
--- /dev/null
+++ b/tensorflow/compiler/aot/test_graph_tfadd.pbtxt
@@ -0,0 +1,63 @@
+node {
+ name : "x_const"
+ op : "Const"
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 1
+ }
+ }
+ }
+ attr {
+ key : "dtype"
+ value {
+ type : DT_INT32
+ }
+ }
+}
+node {
+ name : "y_const"
+ op : "Const"
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ int_val: 2
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name : "x_y_sum"
+ op : "Add"
+ input : "x_const"
+ input : "y_const"
+ attr {
+ key : "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+versions {
+ producer: 15
+}
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
new file mode 100644
index 0000000000..ecb071a416
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -0,0 +1,146 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+
+test_suite(
+ name = "all_tests",
+ tags = ["manual"],
+ tests = [
+ ":test_graph_tfadd_test",
+ ":test_graph_tfadd_with_ckpt_saver_test",
+ ":test_graph_tfadd_with_ckpt_test",
+ ":test_graph_tfgather_test",
+ ":test_graph_tfmatmul_test",
+ ":test_graph_tfmatmulandadd_test",
+ ":tfcompile_test",
+ ],
+)
+
+py_binary(
+ name = "make_test_graphs",
+ testonly = 1,
+ srcs = ["make_test_graphs.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python", # TODO(b/34059704): remove when fixed
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+)
+
+genrule(
+ name = "gen_test_graphs",
+ testonly = 1,
+ outs = [
+ "test_graph_tfadd.pb",
+ "test_graph_tfadd_with_ckpt.pb",
+ "test_graph_tfadd_with_ckpt.ckpt",
+ "test_graph_tfadd_with_ckpt_saver.pb",
+ "test_graph_tfadd_with_ckpt_saver.ckpt",
+ "test_graph_tfadd_with_ckpt_saver.saver",
+ "test_graph_tfgather.pb",
+ "test_graph_tfmatmul.pb",
+ "test_graph_tfmatmulandadd.pb",
+ ],
+ cmd = "$(location :make_test_graphs) --out_dir $(@D)",
+ tags = ["manual"],
+ tools = [":make_test_graphs"],
+)
+
+tf_library(
+ name = "test_graph_tfadd",
+ testonly = 1,
+ config = "test_graph_tfadd.config.pbtxt",
+ cpp_class = "AddComp",
+ graph = "test_graph_tfadd.pb",
+ tags = ["manual"],
+)
+
+tf_library(
+ name = "test_graph_tfadd_with_ckpt",
+ testonly = 1,
+ config = "test_graph_tfadd_with_ckpt.config.pbtxt",
+ cpp_class = "AddWithCkptComp",
+ freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt",
+ graph = "test_graph_tfadd_with_ckpt.pb",
+ tags = ["manual"],
+)
+
+tf_library(
+ name = "test_graph_tfadd_with_ckpt_saver",
+ testonly = 1,
+ config = "test_graph_tfadd_with_ckpt.config.pbtxt",
+ cpp_class = "AddWithCkptSaverComp",
+ freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt",
+ freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver",
+ graph = "test_graph_tfadd_with_ckpt_saver.pb",
+ tags = ["manual"],
+)
+
+tf_library(
+ name = "test_graph_tfgather",
+ testonly = 1,
+ config = "test_graph_tfgather.config.pbtxt",
+ cpp_class = "GatherComp",
+ graph = "test_graph_tfgather.pb",
+ tags = ["manual"],
+)
+
+tf_library(
+ name = "test_graph_tfmatmul",
+ testonly = 1,
+ config = "test_graph_tfmatmul.config.pbtxt",
+ cpp_class = "foo::bar::MatMulComp",
+ graph = "test_graph_tfmatmul.pb",
+ tags = ["manual"],
+)
+
+tf_library(
+ name = "test_graph_tfmatmulandadd",
+ testonly = 1,
+ config = "test_graph_tfmatmulandadd.config.pbtxt",
+ cpp_class = "MatMulAndAddComp",
+ graph = "test_graph_tfmatmulandadd.pb",
+ tags = ["manual"],
+)
+
+cc_test(
+ name = "tfcompile_test",
+ srcs = ["tfcompile_test.cc"],
+ tags = ["manual"],
+ deps = [
+ ":test_graph_tfadd",
+ ":test_graph_tfadd_with_ckpt",
+ ":test_graph_tfadd_with_ckpt_saver",
+ ":test_graph_tfgather",
+ ":test_graph_tfmatmul",
+ ":test_graph_tfmatmulandadd",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
new file mode 100644
index 0000000000..261dfcbdf8
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -0,0 +1,119 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generate tensorflow graphs for testing tfcompile."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags as flags_lib
+from tensorflow.python.training import saver as saver_lib
+
+flags = flags_lib
+FLAGS = flags.FLAGS
+flags.DEFINE_string('out_dir', '',
+ 'Output directory for graphs, checkpoints and savers.')
+
+
+def tfadd():
+ x = constant_op.constant([1], name='x_const')
+ y = constant_op.constant([2], name='y_const')
+ math_ops.add(x, y, name='x_y_sum')
+
+
+def tfadd_with_ckpt():
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ math_ops.add(x, y, name='x_y_sum')
+
+ init_op = variables.initialize_all_variables()
+ saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
+ with session.Session() as sess:
+ sess.run(init_op)
+ sess.run(y.assign(y + 42))
+ # Without the checkpoint, the variable won't be set to 42.
+ ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
+ saver.save(sess, ckpt)
+
+
+def tfadd_with_ckpt_saver():
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = variables.Variable(constant_op.constant([0]), name='y_saved')
+ math_ops.add(x, y, name='x_y_sum')
+
+ init_op = variables.initialize_all_variables()
+ saver = saver_lib.Saver(name='abcprefix', write_version=saver_pb2.SaverDef.V1)
+ with session.Session() as sess:
+ sess.run(init_op)
+ sess.run(y.assign(y + 42))
+ # Without the checkpoint, the variable won't be set to 42.
+ ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
+ saver.save(sess, ckpt_file)
+ # Without the SaverDef, the restore op won't be named correctly.
+ saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
+  with open(saver_file, 'wb') as f:
+ f.write(saver.as_saver_def().SerializeToString())
+
+
+def tfgather():
+ params = array_ops.placeholder(dtypes.float32, name='params')
+ indices = array_ops.placeholder(dtypes.int32, name='indices')
+ array_ops.gather(params, indices, name='gather_output')
+
+
+def tfmatmul():
+ x = array_ops.placeholder(dtypes.float32, name='x_hold')
+ y = array_ops.placeholder(dtypes.float32, name='y_hold')
+ math_ops.matmul(x, y, name='x_y_prod')
+
+
+def tfmatmulandadd():
+ # This tests multiple outputs.
+ x = array_ops.placeholder(dtypes.float32, name='x_hold')
+ y = array_ops.placeholder(dtypes.float32, name='y_hold')
+ math_ops.matmul(x, y, name='x_y_prod')
+ math_ops.add(x, y, name='x_y_sum')
+
+
+def write_graph(build_graph):
+ """Build a graph using build_graph and write it out."""
+ g = ops.Graph()
+ with g.as_default():
+ build_graph()
+ filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
+  with open(filename, 'wb') as f:
+ f.write(g.as_graph_def().SerializeToString())
+
+
+def main(_):
+ write_graph(tfadd)
+ write_graph(tfadd_with_ckpt)
+ write_graph(tfadd_with_ckpt_saver)
+ write_graph(tfgather)
+ write_graph(tfmatmul)
+ write_graph(tfmatmulandadd)
+
+
+if __name__ == '__main__':
+ app.run()
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
new file mode 100644
index 0000000000..5625c0ab03
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd.config.pbtxt
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_const" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_sum" }
+}
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
new file mode 100644
index 0000000000..4d876a6e91
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.config.pbtxt
@@ -0,0 +1,10 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_sum" }
+}
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
new file mode 100644
index 0000000000..648ee31fdb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfgather.config.pbtxt
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "params" }
+ shape {
+ dim { size: 4 }
+ }
+}
+feed {
+ id { node_name: "indices" }
+ shape {
+ dim { size: 2 }
+ }
+}
+fetch {
+ id { node_name: "gather_output" }
+}
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
new file mode 100644
index 0000000000..a3ce2029c1
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmul.config.pbtxt
@@ -0,0 +1,18 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 2 }
+ dim { size: 3 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 3 }
+ dim { size: 2 }
+ }
+}
+fetch {
+ id { node_name: "x_y_prod" }
+}
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
new file mode 100644
index 0000000000..4a4a237a4f
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.config.pbtxt
@@ -0,0 +1,25 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 2 }
+ dim { size: 2 }
+ }
+ name: "x"
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 2 }
+ dim { size: 2 }
+ }
+ name: "y"
+}
+fetch {
+ id { node_name: "x_y_prod" }
+ name: "x_y_prod"
+}
+fetch {
+ id { node_name: "x_y_sum" }
+ name: "x_y_sum"
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
new file mode 100644
index 0000000000..f57d2859df
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -0,0 +1,381 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace {
+
+TEST(TFCompileTest, Add) {
+ AddComp add;
+ EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg1_data(), add.args()[1]);
+
+ add.arg0() = 1;
+ add.arg1() = 2;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 3);
+ EXPECT_EQ(add.result0_data()[0], 3);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ add.arg0_data()[0] = 123;
+ add.arg1_data()[0] = 456;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 579);
+ EXPECT_EQ(add.result0_data()[0], 579);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ const AddComp& add_const = add;
+ EXPECT_EQ(add_const.error_msg(), "");
+ EXPECT_EQ(add_const.arg0(), 123);
+ EXPECT_EQ(add_const.arg0_data()[0], 123);
+ EXPECT_EQ(add_const.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add_const.arg1(), 456);
+ EXPECT_EQ(add_const.arg1_data()[0], 456);
+ EXPECT_EQ(add_const.arg1_data(), add.args()[1]);
+ EXPECT_EQ(add_const.result0(), 579);
+ EXPECT_EQ(add_const.result0_data()[0], 579);
+ EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
+}
+
+// Run tests that use set_argN_data separately, to avoid accidentally re-using
+// non-existent buffers.
+TEST(TFCompileTest, Add_SetArg) {
+ AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
+
+ int32 arg_x = 10;
+ int32 arg_y = 32;
+ add.set_arg0_data(&arg_x);
+ add.set_arg1_data(&arg_y);
+ EXPECT_EQ(add.arg0_data(), add.args()[0]);
+ EXPECT_EQ(add.arg1_data(), add.args()[1]);
+
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 42);
+ EXPECT_EQ(add.result0_data()[0], 42);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+}
+
+TEST(TFCompileTest, AddWithCkpt) {
+ AddWithCkptComp add;
+ EXPECT_EQ(add.arg0_data(), add.args()[0]);
+
+ add.arg0() = 1;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 43);
+ EXPECT_EQ(add.result0_data()[0], 43);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ add.arg0_data()[0] = 111;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 153);
+ EXPECT_EQ(add.result0_data()[0], 153);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ const AddWithCkptComp& add_const = add;
+ EXPECT_EQ(add_const.error_msg(), "");
+ EXPECT_EQ(add_const.arg0(), 111);
+ EXPECT_EQ(add_const.arg0_data()[0], 111);
+ EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.result0(), 153);
+ EXPECT_EQ(add_const.result0_data()[0], 153);
+ EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
+}
+
+TEST(TFCompileTest, AddWithCkptSaver) {
+ AddWithCkptSaverComp add;
+ EXPECT_EQ(add.arg0_data(), add.args()[0]);
+
+ add.arg0() = 1;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 43);
+ EXPECT_EQ(add.result0_data()[0], 43);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ add.arg0_data()[0] = 111;
+ EXPECT_TRUE(add.Run());
+ EXPECT_EQ(add.error_msg(), "");
+ EXPECT_EQ(add.result0(), 153);
+ EXPECT_EQ(add.result0_data()[0], 153);
+ EXPECT_EQ(add.result0_data(), add.results()[0]);
+
+ const AddWithCkptSaverComp& add_const = add;
+ EXPECT_EQ(add_const.error_msg(), "");
+ EXPECT_EQ(add_const.arg0(), 111);
+ EXPECT_EQ(add_const.arg0_data()[0], 111);
+ EXPECT_EQ(add_const.arg0_data(), add_const.args()[0]);
+ EXPECT_EQ(add_const.result0(), 153);
+ EXPECT_EQ(add_const.result0_data()[0], 153);
+ EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
+}
+
+TEST(TFCompileTest, Gather) {
+ GatherComp gather;
+ EXPECT_EQ(gather.arg0_data(), gather.args()[0]);
+ EXPECT_EQ(gather.arg1_data(), gather.args()[1]);
+
+ // Successful gather.
+ {
+ const float params[4] = {1, 2, 3, 4};
+ std::copy(params + 0, params + 4, gather.arg0_data());
+ const int32 indices[2] = {1, 3};
+ std::copy(indices + 0, indices + 2, gather.arg1_data());
+ EXPECT_TRUE(gather.Run());
+ EXPECT_EQ(gather.error_msg(), "");
+ const float results[2] = {2, 4};
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(gather.result0(i), results[i]);
+ EXPECT_EQ(gather.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(gather.result0_data(), gather.results()[0]);
+
+ const GatherComp& gather_const = gather;
+ EXPECT_EQ(gather_const.error_msg(), "");
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(gather_const.arg0(i), params[i]);
+ EXPECT_EQ(gather_const.arg0_data()[i], params[i]);
+ }
+ EXPECT_EQ(gather_const.arg0_data(), gather_const.args()[0]);
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(gather_const.arg1(i), indices[i]);
+ EXPECT_EQ(gather_const.arg1_data()[i], indices[i]);
+ }
+ EXPECT_EQ(gather_const.arg1_data(), gather_const.args()[1]);
+ for (int i = 0; i < 2; ++i) {
+ EXPECT_EQ(gather_const.result0(i), results[i]);
+ EXPECT_EQ(gather_const.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(gather_const.result0_data(), gather.results()[0]);
+ }
+
+ // Bad indices returns an error.
+ {
+ const float params[4] = {1, 2, 3, 4};
+ std::copy(params + 0, params + 4, gather.arg0_data());
+ const int32 indices[2] = {1, 4};
+ std::copy(indices + 0, indices + 2, gather.arg1_data());
+ EXPECT_FALSE(gather.Run());
+ EXPECT_EQ(gather.error_msg(), "Invalid index for gather");
+ }
+}
+
+TEST(TFCompileTest, MatMul2) {
+ Eigen::ThreadPool tp(2);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ foo::bar::MatMulComp matmul;
+ matmul.set_thread_pool(&device);
+ EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
+ EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+
+ // Test using the argN() methods.
+ {
+ matmul.arg0(0, 0) = 1;
+ matmul.arg0(0, 1) = 2;
+ matmul.arg0(0, 2) = 3;
+ matmul.arg0(1, 0) = 4;
+ matmul.arg0(1, 1) = 5;
+ matmul.arg0(1, 2) = 6;
+
+ matmul.arg1(0, 0) = 7;
+ matmul.arg1(0, 1) = 8;
+ matmul.arg1(1, 0) = 9;
+ matmul.arg1(1, 1) = 10;
+ matmul.arg1(2, 0) = 11;
+ matmul.arg1(2, 1) = 12;
+
+ EXPECT_TRUE(matmul.Run());
+ EXPECT_EQ(matmul.error_msg(), "");
+ const float results[4] = {58, 64, 139, 154};
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
+ EXPECT_EQ(matmul.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
+ }
+
+ // Test using the argN_data() methods.
+ {
+ const float args[12] = {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120};
+ std::copy(args + 0, args + 6, matmul.arg0_data());
+ std::copy(args + 6, args + 12, matmul.arg1_data());
+ EXPECT_TRUE(matmul.Run());
+ EXPECT_EQ(matmul.error_msg(), "");
+ const float results[4] = {5800, 6400, 13900, 15400};
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
+ EXPECT_EQ(matmul.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
+
+ const foo::bar::MatMulComp& matmul_const = matmul;
+ EXPECT_EQ(matmul_const.error_msg(), "");
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(matmul_const.arg0(i / 3, i % 3), args[i]);
+ EXPECT_EQ(matmul_const.arg0_data()[i], args[i]);
+ }
+ EXPECT_EQ(matmul_const.arg0_data(), matmul.args()[0]);
+ for (int i = 0; i < 6; ++i) {
+ EXPECT_EQ(matmul_const.arg1(i / 2, i % 2), args[i + 6]);
+ EXPECT_EQ(matmul_const.arg1_data()[i], args[i + 6]);
+ }
+ EXPECT_EQ(matmul_const.arg1_data(), matmul.args()[1]);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]);
+ EXPECT_EQ(matmul_const.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]);
+ }
+}
+
+// Run tests that use set_argN_data separately, to avoid accidentally re-using
+// non-existent buffers.
+TEST(TFCompileTest, MatMul2_SetArg) {
+ Eigen::ThreadPool tp(2);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ foo::bar::MatMulComp matmul(
+ foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
+ matmul.set_thread_pool(&device);
+
+ // Test using the set_argN_data() methods.
+ float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}};
+ float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}};
+ matmul.set_arg0_data(&arg0);
+ matmul.set_arg1_data(&arg1);
+ EXPECT_EQ(matmul.arg0_data(), matmul.args()[0]);
+ EXPECT_EQ(matmul.arg1_data(), matmul.args()[1]);
+
+ EXPECT_TRUE(matmul.Run());
+ EXPECT_EQ(matmul.error_msg(), "");
+ const float results[4] = {58, 64, 139, 154};
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]);
+ EXPECT_EQ(matmul.result0_data()[i], results[i]);
+ }
+ EXPECT_EQ(matmul.result0_data(), matmul.results()[0]);
+}
+
+TEST(TFCompileTest, MatMulAndAdd1) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ MatMulAndAddComp muladd;
+ muladd.set_thread_pool(&device);
+ EXPECT_EQ(muladd.arg0_data(), muladd.args()[0]);
+ EXPECT_EQ(muladd.arg1_data(), muladd.args()[1]);
+
+ // Test methods with positional args and results.
+ {
+ const float args[8] = {1, 2, 3, 4, 5, 6, 7, 8};
+ std::copy(args + 0, args + 4, muladd.arg0_data());
+ std::copy(args + 4, args + 8, muladd.arg1_data());
+ EXPECT_TRUE(muladd.Run());
+ EXPECT_EQ(muladd.error_msg(), "");
+ const float results0[4] = {19, 22, 43, 50};
+ const float results1[4] = {6, 8, 10, 12};
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd.result0(i / 2, i % 2), results0[i]);
+ EXPECT_EQ(muladd.result0_data()[i], results0[i]);
+ EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]);
+ EXPECT_EQ(muladd.result1_data()[i], results1[i]);
+ }
+ EXPECT_EQ(muladd.result0_data(), muladd.results()[0]);
+ EXPECT_EQ(muladd.result1_data(), muladd.results()[1]);
+
+ const MatMulAndAddComp& muladd_const = muladd;
+ EXPECT_EQ(muladd_const.error_msg(), "");
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.arg0(i / 2, i % 2), args[i]);
+ EXPECT_EQ(muladd_const.arg0_data()[i], args[i]);
+ }
+ EXPECT_EQ(muladd_const.arg0_data(), muladd.args()[0]);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.arg1(i / 2, i % 2), args[i + 4]);
+ EXPECT_EQ(muladd_const.arg1_data()[i], args[i + 4]);
+ }
+ EXPECT_EQ(muladd_const.arg1_data(), muladd.args()[1]);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.result0(i / 2, i % 2), results0[i]);
+ EXPECT_EQ(muladd_const.result0_data()[i], results0[i]);
+ EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]);
+ EXPECT_EQ(muladd_const.result1_data()[i], results1[i]);
+ }
+ EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]);
+ EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]);
+ }
+
+ // Test methods with named args and results.
+ {
+ const float args[8] = {10, 20, 30, 40, 50, 60, 70, 80};
+ std::copy(args + 0, args + 4, muladd.arg_x_data());
+ std::copy(args + 4, args + 8, muladd.arg_y_data());
+ EXPECT_TRUE(muladd.Run());
+ EXPECT_EQ(muladd.error_msg(), "");
+ const float results0[4] = {1900, 2200, 4300, 5000};
+ const float results1[4] = {60, 80, 100, 120};
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd.result_x_y_prod(i / 2, i % 2), results0[i]);
+ EXPECT_EQ(muladd.result_x_y_prod_data()[i], results0[i]);
+ EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]);
+ EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]);
+ }
+ EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]);
+ EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]);
+
+ // Test const methods.
+ const MatMulAndAddComp& muladd_const = muladd;
+ EXPECT_EQ(muladd_const.error_msg(), "");
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.arg_x(i / 2, i % 2), args[i]);
+ EXPECT_EQ(muladd_const.arg_x_data()[i], args[i]);
+ }
+ EXPECT_EQ(muladd_const.arg_x_data(), muladd.args()[0]);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.arg_y(i / 2, i % 2), args[i + 4]);
+ EXPECT_EQ(muladd_const.arg_y_data()[i], args[i + 4]);
+ }
+ EXPECT_EQ(muladd_const.arg_y_data(), muladd.args()[1]);
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(muladd_const.result_x_y_prod(i / 2, i % 2), results0[i]);
+ EXPECT_EQ(muladd_const.result_x_y_prod_data()[i], results0[i]);
+ EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]);
+ EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]);
+ }
+ EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]);
+ EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]);
+ }
+}
+
+} // namespace
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
new file mode 100644
index 0000000000..6f2e0958fd
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -0,0 +1,285 @@
+# -*- Python -*-
+
+"""Build macro that compiles a TensorFlow graph into a cc_library.
+
+To use from your BUILD file, add the following line to load the macro:
+
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+
+Then call the macro like this:
+
+tf_library(
+ name = "test_graph_tfmatmul",
+ config = "test_graph_tfmatmul.config.pbtxt",
+ cpp_class = "MatMulComp",
+ graph = ":test_graph_tfmatmul.pb",
+)
+"""
+
+load("//tensorflow:tensorflow.bzl", "if_android", "tf_copts")
+
+def tf_library(name, graph, config,
+ freeze_checkpoint=None, freeze_saver=None,
+ cpp_class=None, gen_test=True, gen_benchmark=True,
+ visibility=None, testonly=None,
+ tfcompile_flags=None,
+ tfcompile_tool="//tensorflow/compiler/aot:tfcompile",
+ deps=None, tags=None):
+ """Runs tfcompile to compile a TensorFlow graph into executable code.
+
+ Args:
+ name: The name of the build rule.
+ graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' it
+ is expected to be in the human-readable proto text format, otherwise it is
+ expected to be in the proto binary format.
+ config: File containing tensorflow.tfcompile.Config proto. If the file ends
+ in '.pbtxt' it is expected to be in the human-readable proto text format,
+ otherwise it is expected to be in the proto binary format.
+ freeze_checkpoint: If provided, run freeze_graph with this checkpoint to
+ convert variables into constants.
+ freeze_saver: If provided, run freeze_graph with this saver, in SaverDef
+ binary form, to convert variables into constants.
+ cpp_class: The name of the generated C++ class, wrapping the generated
+ function. The syntax of this flag is
+ [[<optional_namespace>::],...]<class_name>. This mirrors the C++ syntax
+ for referring to a class, where multiple namespaces may precede the class
+ name, separated by double-colons. The class will be generated in the
+ given namespace(s), or if no namespaces are given, within the global
+ namespace.
+ gen_test: If True, also generate a cc_test rule that builds a simple
+ test and benchmark.
+ gen_benchmark: If True, also generate a binary with a simple benchmark.
+ Unlike the output of gen_test, this benchmark can be run on android.
+ visibility: Bazel build visibility.
+ testonly: Bazel testonly attribute.
+ tfcompile_flags: Extra flags to pass to tfcompile to control compilation.
+    tfcompile_tool: The tfcompile binary. A non-default binary may be passed to
+      use a tfcompile built with extra dependencies.
+    deps: a list of extra deps to include in the build rules for the generated
+      library.
+ tags: tags to apply to subsidiary build rules.
+
+ The output header is called <name>.h.
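+
+  Example of using the generated header from C++ (a rough sketch; the class and
+  accessor names follow cpp_class and the generated code, as exercised in
+  tests/tfcompile_test.cc):
+
+    #include "test_graph_tfmatmul.h"  // the generated <name>.h header
+
+    foo::bar::MatMulComp matmul;
+    matmul.arg0(0, 0) = 1;           // set inputs
+    bool ok = matmul.Run();          // run the compiled graph
+    float result = matmul.result0(0, 0);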
+ """
+ if not cpp_class:
+ fail("cpp_class must be specified")
+
+ tfcompile_graph = graph
+ if freeze_checkpoint or freeze_saver:
+ if not freeze_checkpoint:
+ fail("freeze_checkpoint must be specified when freeze_saver is specified")
+
+ freeze_name = "freeze_" + name
+ freeze_file = freeze_name + ".pb"
+
+ # First run tfcompile to generate the list of out_nodes.
+ out_nodes_file = "out_nodes_" + freeze_name
+ native.genrule(
+ name=("gen_" + out_nodes_file),
+ srcs=[config],
+ outs=[out_nodes_file],
+ cmd=("$(location " + tfcompile_tool + ")" +
+ " --config=$(location " + config + ")" +
+ " --dump_fetch_nodes > $@"),
+ tools=[tfcompile_tool],
+ # Run tfcompile on the build host, rather than forge, since it's
+ # typically way faster on the local machine.
+ local=1,
+ tags=tags,
+ )
+
+ # Now run freeze_graph to convert variables into constants.
+ freeze_args = (" --input_graph=$(location " + graph + ")" +
+ " --input_binary=" + str(not graph.endswith(".pbtxt")) +
+ " --input_checkpoint=$(location " + freeze_checkpoint + ")" +
+ " --output_graph=$(location " + freeze_file + ")" +
+ " --output_node_names=$$(<$(location " + out_nodes_file +
+ "))")
+ freeze_saver_srcs = []
+ if freeze_saver:
+ freeze_args += " --input_saver=$(location " + freeze_saver + ")"
+ freeze_saver_srcs += [freeze_saver]
+ native.genrule(
+ name=freeze_name,
+ srcs=[
+ graph,
+ freeze_checkpoint,
+ out_nodes_file,
+ ] + freeze_saver_srcs,
+ outs=[freeze_file],
+ cmd=("$(location //tensorflow/python/tools:freeze_graph)" +
+ freeze_args),
+ tools=["//tensorflow/python/tools:freeze_graph"],
+ tags=tags,
+ )
+ tfcompile_graph = freeze_file
+
+ # Rule that runs tfcompile to produce the header and object file.
+ header_file = name + ".h"
+ object_file = name + ".o"
+ ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_")
+ native.genrule(
+ name=("gen_" + name),
+ srcs=[
+ tfcompile_graph,
+ config,
+ ],
+ outs=[
+ header_file,
+ object_file,
+ ],
+ cmd=("$(location " + tfcompile_tool + ")" +
+ " --graph=$(location " + tfcompile_graph + ")" +
+ " --config=$(location " + config + ")" +
+ " --entry_point=" + ep +
+ " --cpp_class=" + cpp_class +
+ " --target_triple=" + target_llvm_triple() +
+ " --out_header=$(@D)/" + header_file +
+ " --out_object=$(@D)/" + object_file +
+ " " + (tfcompile_flags or "")),
+ tools=[tfcompile_tool],
+ visibility=visibility,
+ testonly=testonly,
+ # Run tfcompile on the build host since it's typically faster on the local
+ # machine.
+ #
+ # Note that setting the local=1 attribute on a *test target* causes the
+ # test infrastructure to skip that test. However, this is a genrule, not a
+ # test target, and runs with --genrule_strategy=forced_forge, meaning the
+ # local=1 attribute is ignored, and the genrule is still run.
+ #
+ # https://www.bazel.io/versions/master/docs/be/general.html#genrule
+ local=1,
+ tags=tags,
+ )
+
+ # The cc_library rule packaging up the header and object file, and needed
+ # kernel implementations.
+ native.cc_library(
+ name=name,
+ srcs=[object_file],
+ hdrs=[header_file],
+ visibility=visibility,
+ testonly=testonly,
+ deps = [
+ # TODO(cwhipkey): only depend on kernel code that the model actually needed.
+ "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32",
+ "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64",
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
+ "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//third_party/eigen3",
+ "//tensorflow/core:framework_lite",
+ ] + (deps or []),
+ tags=tags,
+ )
+
+ # Variables used for gen_test and gen_benchmark.
+ no_ns_name = ""
+ cpp_class_split = cpp_class.rsplit("::", maxsplit=1)
+ if len(cpp_class_split) == 1:
+ no_ns_name = cpp_class_split[0]
+ else:
+ no_ns_name = cpp_class_split[1]
+ sed_replace = (
+ "-e \"s|{{TFCOMPILE_HEADER}}|$(location " + header_file + ")|g\" " +
+ "-e \"s|{{TFCOMPILE_CPP_CLASS}}|" + cpp_class + "|g\" " +
+ "-e \"s|{{TFCOMPILE_NAME}}|" + no_ns_name + "|g\" ")
+
+ if gen_test:
+ test_name = name + "_test"
+ test_file = test_name + ".cc"
+ # Rule to rewrite test.cc to produce the test_file.
+ native.genrule(
+ name=("gen_" + test_name),
+ testonly=1,
+ srcs=[
+ "//tensorflow/compiler/aot:test.cc",
+ header_file,
+ ],
+ outs=[test_file],
+ cmd=("sed " + sed_replace +
+ " $(location //tensorflow/compiler/aot:test.cc) " +
+ "> $(OUTS)"),
+ tags=tags,
+ )
+
+ # The cc_test rule for the generated code.
+ native.cc_test(
+ name=test_name,
+ srcs=[test_file],
+ deps=[
+ ":" + name,
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/aot:tf_library_test_main",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//third_party/eigen3",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+ tags=tags,
+ )
+
+ if gen_benchmark:
+ benchmark_name = name + "_benchmark"
+ benchmark_file = benchmark_name + ".cc"
+ benchmark_main = ("//tensorflow/compiler/aot:" +
+ "benchmark_main.template")
+
+ # Rule to rewrite benchmark.cc to produce the benchmark_file.
+ native.genrule(
+ name=("gen_" + benchmark_name),
+ srcs=[
+ benchmark_main,
+ header_file,
+ ],
+ testonly = testonly,
+ outs=[benchmark_file],
+ cmd=("sed " + sed_replace +
+ " $(location " + benchmark_main + ") " +
+ "> $(OUTS)"),
+ tags=tags,
+ )
+
+ # The cc_benchmark rule for the generated code.
+ #
+ # Note: to get smaller size on android for comparison, compile with:
+ # --copt=-fvisibility=hidden
+ # --copt=-D_LIBCPP_TYPE_VIS=_LIBCPP_HIDDEN
+ # --copt=-D_LIBCPP_EXCEPTION_ABI=_LIBCPP_HIDDEN
+ native.cc_binary(
+ name=benchmark_name,
+ srcs=[benchmark_file],
+ testonly = testonly,
+ copts = tf_copts(),
+ linkopts = if_android(["-pie", "-s"]),
+ deps=[
+ ":" + name,
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/compiler/aot:benchmark",
+ "//tensorflow/compiler/aot:runtime",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//third_party/eigen3",
+ ] + if_android([
+ "//tensorflow/compiler/aot:benchmark_extra_android",
+ ]),
+ tags=tags,
+ )
+
+
+def target_llvm_triple():
+ """Returns the target LLVM triple to be used for compiling the target."""
+ # TODO(toddw): Add target_triple for other targets. For details see:
+ # http://llvm.org/docs/doxygen/html/Triple_8h_source.html
+ return select({
+ "//tensorflow:android_arm": "armv7-none-android",
+ "//tensorflow:android_arm64": "aarch64-none-android",
+ "//conditions:default": "x86_64-pc-linux",
+ })
diff --git a/tensorflow/compiler/aot/tfcompile.proto b/tensorflow/compiler/aot/tfcompile.proto
new file mode 100644
index 0000000000..be3f504350
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile.proto
@@ -0,0 +1,43 @@
+syntax = "proto3";
+
+package tensorflow.tfcompile;
+option cc_enable_arenas = true;
+option java_outer_classname = "CompileProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.tfcompile";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+
+// TensorId identifies a tensor in a TensorFlow graph, by specifying the output
+// index of a particular node in the graph. If the output of the named node
+// feeds into other node(s), this corresponds to one or more edges. Otherwise
+// it doesn't correspond to any existing edges at all, e.g. for output nodes.
+message TensorId {
+ string node_name = 1;
+ int64 output_index = 2;
+};
+
+// Feed represents a single feed tensor in the graph, which corresponds to an
+// input argument for the generated function.
+message Feed {
+ TensorId id = 1;
+ TensorShapeProto shape = 2;
+ string name = 3; // Optional name for generated code.
+};
+
+// Fetch represents a single fetch tensor in the graph, which corresponds to an
+// output argument for the generated function.
+message Fetch {
+ TensorId id = 1;
+ string name = 2; // Optional name for generated code.
+};
+
+// Config represents configuration information for tfcompile.
+message Config {
+ // Each feed is a positional input argument for the generated function. The
+ // order of each entry matches the order of each input argument.
+ repeated Feed feed = 1;
+ // Each fetch is a positional output argument for the generated function. The
+ // order of each entry matches the order of each output argument.
+ repeated Fetch fetch = 2;
+};
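+
+// Example Config in the text proto format accepted for '.pbtxt' files
+// (illustrative; the node names are hypothetical):
+//
+//   feed { id { node_name: "x_hold" } shape { dim { size: 2 } } name: "x" }
+//   feed { id { node_name: "y_hold" } shape { dim { size: 2 } } }
+//   fetch { id { node_name: "x_y_sum" } name: "sum" }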
diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc
new file mode 100644
index 0000000000..85ef9560bb
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile_main.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/aot/codegen.h"
+#include "tensorflow/compiler/aot/compile.h"
+#include "tensorflow/compiler/aot/flags.h"
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+const char kUsageHeader[] =
+ "tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
+ "resulting in an object file compiled for your target architecture, and a\n"
+ "header file that gives access to the functionality in the object file.\n"
+ "A typical invocation looks like this:\n"
+ "\n"
+ " $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt\n"
+ "\n";
+
+Status ReadProtoFile(const string& kind, const string& fname,
+ protobuf::Message* proto) {
+ if (StringPiece(fname).ends_with(".pbtxt")) {
+ return ReadTextProto(Env::Default(), fname, proto);
+ } else {
+ return ReadBinaryProto(Env::Default(), fname, proto);
+ }
+}
+
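+// Parses a tensor name of the form "node_name" or "node_name:output_index"
+// into a TensorId; e.g. (illustrative) "x_hold:1" yields node_name "x_hold"
+// and output_index 1, while a bare "x_hold" yields output_index 0.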
+void ParseTensorId(const string& name, TensorId* id) {
+ const std::pair<StringPiece, int> name_index = ParseTensorName(name);
+ id->set_node_name(name_index.first.ToString());
+ id->set_output_index(name_index.second);
+}
+
+Status Main(const MainFlags& flags) {
+ // Process config.
+ Config config;
+ TF_RETURN_IF_ERROR(ReadProtoFile("config", flags.config, &config));
+ TF_RETURN_IF_ERROR(ValidateConfig(config));
+ if (flags.dump_fetch_nodes) {
+ std::set<string> nodes;
+ for (const Fetch& fetch : config.fetch()) {
+ nodes.insert(fetch.id().node_name());
+ }
+ std::cout << str_util::Join(nodes, ",");
+ return Status::OK();
+ }
+
+ // Read and initialize the graph.
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(ReadProtoFile("graph", flags.graph, &graph_def));
+ std::unique_ptr<Graph> graph;
+ FunctionLibraryDefinition flib(OpRegistry::Global(), graph_def.library());
+ TF_RETURN_IF_ERROR(InitGraph(graph_def, config, flags, &flib, &graph));
+
+ CompileResult compile_result;
+ TF_RETURN_IF_ERROR(
+ CompileGraph(std::move(graph), flags, &flib, &compile_result));
+
+ // Write output files.
+ Env* env = Env::Default();
+ const std::vector<char>& obj = compile_result.aot->object_file_data();
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object,
+ StringPiece(obj.data(), obj.size())));
+ HeaderOpts header_opts;
+ TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name,
+ &header_opts.namespaces));
+ string header;
+ TF_RETURN_IF_ERROR(
+ GenerateHeader(header_opts, config, compile_result, &header));
+ TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
+ return Status::OK();
+}
+
+} // end namespace tfcompile
+} // end namespace tensorflow
+
+int main(int argc, char** argv) {
+ tensorflow::tfcompile::MainFlags flags;
+ flags.target_triple = "x86_64-pc-linux";
+ flags.out_object = "out.o";
+ flags.out_header = "out.h";
+
+ std::vector<tensorflow::Flag> flag_list;
+ AppendMainFlags(&flag_list, &flags);
+ xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
+
+ tensorflow::string usage = tensorflow::tfcompile::kUsageHeader;
+ usage += tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ QCHECK(parsed_flags_ok) << "\n" << usage;
+
+ tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
+ QCHECK(argc == 1 && !flags.config.empty() &&
+ (flags.dump_fetch_nodes ||
+ (!flags.graph.empty() && !flags.entry_point.empty())))
+ << "\n"
+ << usage;
+
+ TF_QCHECK_OK(tensorflow::tfcompile::Main(flags));
+ return 0;
+}
diff --git a/tensorflow/compiler/aot/tfcompile_util.cc b/tensorflow/compiler/aot/tfcompile_util.cc
new file mode 100644
index 0000000000..fd073a2e26
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile_util.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+
+#include <set>
+
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+namespace {
+
+bool IsAlpha(char c) {
+ return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
+}
+
+bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
+
+Status ValidateTensorId(const TensorId& id) {
+ if (id.node_name().empty()) {
+ return errors::InvalidArgument("TensorId node_name must be non-empty");
+ }
+ if (id.output_index() < 0) {
+ return errors::InvalidArgument("TensorId output_index must be non-negative");
+ }
+ return Status::OK();
+}
+
+Status ValidateFeedFetchName(const string& kind, const string& name,
+ std::set<string>* names) {
+ if (!name.empty()) {
+ TF_RETURN_IF_ERROR(ValidateCppIdent(name, kind + " name"));
+ if (!names->insert(name).second) {
+ return errors::InvalidArgument("duplicate ", kind, " name: ", name);
+ }
+ }
+ return Status::OK();
+}
+
+Status CheckFeedFetchNameConflicts(const string& kind,
+ const std::set<string>& names) {
+ // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
+ // since that will cause a collision in codegen symbols.
+ for (const string& name : names) {
+ const string name_data(name + "_data");
+ if (names.find(name_data) != names.end()) {
+ return errors::InvalidArgument("conflicting ", kind, " name: ", name,
+ " and ", name_data);
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+Status ValidateCppIdent(StringPiece ident, StringPiece msg) {
+ if (ident.empty()) {
+ return errors::InvalidArgument("empty identifier: ", msg);
+ }
+ // Require that the identifier starts with a nondigit, and is composed of
+ // nondigits and digits, as specified in section [2.11 Identifiers] of the
+ // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
+ // defined as [0-9].
+ //
+ // Technically the standard also allows for `universal-character-name`, with a
+ // table of allowed unicode ranges, as well as `other implementation-defined
+ // characters`. We disallow those here to give better error messages, at the
+ // expense of being more restrictive than the standard.
+ if (ident[0] != '_' && !IsAlpha(ident[0])) {
+ return errors::InvalidArgument("illegal leading char: ", msg);
+ }
+ for (size_t pos = 1; pos < ident.size(); ++pos) {
+ if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
+ return errors::InvalidArgument("illegal char: ", msg);
+ }
+ }
+ return Status::OK();
+}
+
+Status ValidateConfig(const Config& config) {
+ std::set<string> names;
+ for (const Feed& feed : config.feed()) {
+ TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
+ TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
+ TF_RETURN_IF_ERROR(ValidateFeedFetchName("feed", feed.name(), &names));
+ }
+ TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
+ names.clear();
+ for (const Fetch& fetch : config.fetch()) {
+ TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
+ TF_RETURN_IF_ERROR(ValidateFeedFetchName("fetch", fetch.name(), &names));
+ }
+ TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
+ if (config.feed().empty() || config.fetch().empty()) {
+ return errors::InvalidArgument("feeds and fetches must be specified");
+ }
+ return Status::OK();
+}
+
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/tfcompile_util.h b/tensorflow/compiler/aot/tfcompile_util.h
new file mode 100644
index 0000000000..651d75d0d0
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile_util.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+#define TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
+
+#include "tensorflow/compiler/aot/tfcompile.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace tfcompile {
+
+// ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is
+// appended to error messages.
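+// For example (illustrative), ValidateCppIdent("x_hold", "feed name") returns
+// OK, while ValidateCppIdent("0x", "feed name") returns an InvalidArgument
+// error because the identifier starts with a digit.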
+Status ValidateCppIdent(StringPiece ident, StringPiece msg);
+
+// ValidateConfig returns OK iff config is valid.
+Status ValidateConfig(const Config& config);
+
+} // namespace tfcompile
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_AOT_TFCOMPILE_UTIL_H_
diff --git a/tensorflow/compiler/aot/tfcompile_util_test.cc b/tensorflow/compiler/aot/tfcompile_util_test.cc
new file mode 100644
index 0000000000..108ab1eab7
--- /dev/null
+++ b/tensorflow/compiler/aot/tfcompile_util_test.cc
@@ -0,0 +1,185 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/aot/tfcompile_util.h"
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace tfcompile {
+namespace {
+
+void ExpectErrorContains(Status status, StringPiece str) {
+ EXPECT_NE(Status::OK(), status);
+ EXPECT_TRUE(StringPiece(status.error_message()).contains(str))
+ << "expected error: " << status.error_message() << " to contain: " << str;
+}
+
+TEST(ValidateCppIdent, Simple) {
+ TF_EXPECT_OK(ValidateCppIdent("a", ""));
+ TF_EXPECT_OK(ValidateCppIdent("abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc", ""));
+ TF_EXPECT_OK(ValidateCppIdent("_abc123", ""));
+ // Make sure we didn't skip a valid letter or digit
+ string ident;
+ for (char c = 'a'; c <= 'z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = 'A'; c <= 'Z'; c++) {
+ ident.append(1, c);
+ }
+ for (char c = '0'; c <= '9'; c++) {
+ ident.append(1, c);
+ }
+ ident += "_";
+ TF_EXPECT_OK(ValidateCppIdent(ident, ""));
+
+ ExpectErrorContains(ValidateCppIdent("", ""), "empty identifier");
+ ExpectErrorContains(ValidateCppIdent(" ", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("0", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(".", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent(":", ""), "illegal leading char");
+ ExpectErrorContains(ValidateCppIdent("a.", ""), "illegal char");
+ ExpectErrorContains(ValidateCppIdent("a:", ""), "illegal char");
+}
+
+TEST(ValidateConfig, Good) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ feed->mutable_id()->set_output_index(123);
+ feed->set_name("foo_debug");
+ feed = config.add_feed();
+ feed->mutable_id()->set_node_name("bar");
+ feed->mutable_id()->set_output_index(0);
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("baz");
+ fetch->mutable_id()->set_output_index(456);
+ fetch->set_name("baz_debug");
+ fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("banana");
+ fetch->mutable_id()->set_output_index(0);
+ TF_EXPECT_OK(ValidateConfig(config));
+}
+
+TEST(ValidateConfig, BadEmpty) {
+ Config config;
+ ExpectErrorContains(ValidateConfig(config),
+ "feeds and fetches must be specified");
+}
+
+TEST(ValidateConfig, BadNoFeed) {
+ Config config;
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("foo");
+ ExpectErrorContains(ValidateConfig(config),
+ "feeds and fetches must be specified");
+}
+
+TEST(ValidateConfig, BadNoFetch) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ ExpectErrorContains(ValidateConfig(config),
+ "feeds and fetches must be specified");
+}
+
+TEST(ValidateConfig, BadFeedNodeName) {
+ Config config;
+ config.add_feed();
+ ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
+}
+
+TEST(ValidateConfig, BadFeedOutputIndex) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ feed->mutable_id()->set_output_index(-1);
+ ExpectErrorContains(ValidateConfig(config), "output_index must be non-negative");
+}
+
+TEST(ValidateConfig, BadFetchNodeName) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ config.add_fetch();
+ ExpectErrorContains(ValidateConfig(config), "node_name must be non-empty");
+}
+
+TEST(ValidateConfig, BadFetchOutputIndex) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("bar");
+ fetch->mutable_id()->set_output_index(-1);
+ ExpectErrorContains(ValidateConfig(config), "output_index must be non-negative");
+}
+
+TEST(ValidateConfig, DuplicateFeedName) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ feed->set_name("dup");
+ feed = config.add_feed();
+ feed->mutable_id()->set_node_name("bar");
+ feed->set_name("dup");
+ ExpectErrorContains(ValidateConfig(config), "duplicate feed name");
+}
+
+TEST(ValidateConfig, DuplicateFetchName) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("bar");
+ fetch->set_name("dup");
+ fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("baz");
+ fetch->set_name("dup");
+ ExpectErrorContains(ValidateConfig(config), "duplicate fetch name");
+}
+
+TEST(ValidateConfig, ConflictingFeedName) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ feed->set_name("conflict");
+ feed = config.add_feed();
+ feed->mutable_id()->set_node_name("bar");
+ feed->set_name("conflict_data");
+ ExpectErrorContains(ValidateConfig(config), "conflicting feed name");
+}
+
+TEST(ValidateConfig, ConflictingFetchName) {
+ Config config;
+ Feed* feed = config.add_feed();
+ feed->mutable_id()->set_node_name("foo");
+ Fetch* fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("bar");
+ fetch->set_name("conflict");
+ fetch = config.add_fetch();
+ fetch->mutable_id()->set_node_name("baz");
+ fetch->set_name("conflict_data");
+ ExpectErrorContains(ValidateConfig(config), "conflicting fetch name");
+}
+
+} // namespace
+} // namespace tfcompile
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
new file mode 100644
index 0000000000..d2cfca207f
--- /dev/null
+++ b/tensorflow/compiler/jit/BUILD
@@ -0,0 +1,282 @@
+licenses(["notice"]) # Apache 2.0
+
+package_group(
+ name = "internal",
+ includes = [
+ "//tensorflow/compiler/tf2xla:internal",
+ ],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/tf2xla:friends",
+ ],
+)
+
+package(
+ default_visibility = [":internal"],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+
+# Target that bundles up the XLA CPU and GPU JIT devices.
+cc_library(
+ name = "jit",
+ visibility = [":friends"],
+ deps = [
+ ":xla_cpu_device",
+ ":xla_cpu_jit",
+ ":xla_gpu_device",
+ ":xla_gpu_jit",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xla_cpu_jit",
+ visibility = [":friends"],
+ deps = [
+ ":jit_compilation_passes",
+ ":xla_local_launch_op",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xla_gpu_jit",
+ visibility = [":friends"],
+ deps = [
+ ":jit_compilation_passes",
+ ":xla_local_launch_op",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xla_cpu_device",
+ srcs = ["xla_cpu_device.cc"],
+ visibility = [":friends"],
+ deps = [
+ ":jit_compilation_passes",
+ ":xla_device",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xla_gpu_device",
+ srcs = ["xla_gpu_device.cc"],
+ visibility = [":friends"],
+ deps = [
+ ":jit_compilation_passes",
+ ":xla_device",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+# Internal targets below this point.
+
+cc_library(
+ name = "common",
+ srcs = [
+ "defs.cc",
+ ],
+ hdrs = [
+ "defs.h",
+ ],
+ visibility = [":friends"],
+)
+
+cc_library(
+ name = "xla_device",
+ srcs = [
+ "xla_device.cc",
+ "xla_device_context.cc",
+ "xla_device_launch_op.cc",
+ "xla_device_ops.cc",
+ ],
+ hdrs = [
+ "xla_device.h",
+ "xla_device_context.h",
+ "xla_device_launch_op.h",
+ "xla_device_ops.h",
+ ],
+ deps = [
+ ":common",
+ ":jit_compilation_passes",
+ ":xla_compilation_cache",
+ ":xla_local_launch_op",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core/kernels:assign_op",
+ "//tensorflow/core/kernels:constant_op",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:identity_op",
+ "//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:sendrecv_ops",
+ "//tensorflow/core/kernels:variable_ops",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "xla_compilation_cache",
+ srcs = ["xla_compilation_cache.cc"],
+ hdrs = ["xla_compilation_cache.h"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "jit_compilation_passes",
+ srcs = ["jit_compilation_pass_registration.cc"],
+ deps = [
+ ":compilation_passes",
+ "//tensorflow/core:core_cpu_internal",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "compilation_passes",
+ srcs = [
+ "build_xla_launch_ops_pass.cc",
+ "encapsulate_subgraphs_pass.cc",
+ "graph_to_functiondef.cc",
+ "mark_for_compilation_pass.cc",
+ ],
+ hdrs = [
+ "build_xla_launch_ops_pass.h",
+ "encapsulate_subgraphs_pass.h",
+ "graph_to_functiondef.h",
+ "mark_for_compilation_pass.h",
+ ],
+ deps = [
+ ":common",
+ ":parallel_check_op",
+ ":xla_local_launch_op",
+ "//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/compiler/jit/legacy_flags:encapsulate_subgraphs_pass_flags",
+ "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
+ "//tensorflow/compiler/tf2xla:const_analysis",
+ "//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "compilation_passes_test",
+ size = "small",
+ srcs = [
+ "encapsulate_subgraphs_pass_test.cc",
+ "graph_to_functiondef_test.cc",
+ "mark_for_compilation_pass_test.cc",
+ ],
+ deps = [
+ ":compilation_passes",
+ ":xla_local_launch_op",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/kernels:xla_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "xla_local_launch_op",
+ srcs = ["xla_local_launch_op.cc"],
+ hdrs = ["xla_local_launch_op.h"],
+ deps = [
+ ":common",
+ ":xla_compilation_cache",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+ alwayslink = 1,
+)
+
+tf_kernel_library(
+ name = "parallel_check_op",
+ srcs = ["parallel_check_op.cc"],
+ visibility = [":friends"],
+ deps = [
+ "//tensorflow/compiler/jit/legacy_flags:parallel_check_op_flags",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
new file mode 100644
index 0000000000..8fde197400
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.cc
@@ -0,0 +1,215 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/jit/xla_local_launch_op.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+static Status BuildLaunchNode(
+ const string& nodename, const string& function_name,
+ const AttrValueMap& function_attr, const string& device_name,
+ const DataTypeVector& constant_dtypes, const DataTypeVector& arg_dtypes,
+ const DataTypeVector& result_dtypes, Graph* graph, Node** node) {
+ NodeDef def;
+ def.set_name(graph->NewName(nodename));
+ def.set_op("_XlaLaunch");
+ def.set_device(device_name);
+ AddNodeAttr("Tconstants", constant_dtypes, &def);
+ AddNodeAttr("Targs", arg_dtypes, &def);
+ AddNodeAttr("Tresults", result_dtypes, &def);
+ NameAttrList function;
+ function.set_name(function_name);
+ *function.mutable_attr() = function_attr;
+ AddNodeAttr("function", function, &def);
+
+ Status status;
+ *node = graph->AddNode(def, &status);
+ return status;
+}
+
+static Status ReplaceNodeWithXlaLaunch(Graph* graph, Node* node) {
+ VLOG(2) << "Replacing " << node->name() << " with XlaLaunch";
+
+ int num_constant_args;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->def(), kXlaNumConstantArgsAttr, &num_constant_args));
+
+ if (num_constant_args < 0 || num_constant_args > node->input_types().size()) {
+ return errors::InvalidArgument(
+ "Invalid number of constant arguments to XLA kernel");
+ }
+ DataTypeVector const_dtypes(node->input_types().begin(),
+ node->input_types().begin() + num_constant_args);
+ DataTypeVector arg_dtypes(node->input_types().begin() + num_constant_args,
+ node->input_types().end());
+
+ // Build a _XlaLaunch operator to execute the function body.
+ Node* launch_node;
+ TF_RETURN_IF_ERROR(
+ BuildLaunchNode(graph->NewName(node->name()), node->type_string(),
+ node->def().attr(), node->def().device(), const_dtypes,
+ arg_dtypes, node->output_types(), graph, &launch_node));
+ launch_node->set_assigned_device_name(node->assigned_device_name());
+
+ // Copy incoming edges to the launch node.
+ for (const Edge* edge : node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ graph->AddControlEdge(edge->src(), launch_node);
+ } else {
+ graph->AddEdge(edge->src(), edge->src_output(), launch_node,
+ edge->dst_input());
+ }
+ }
+
+ // Copy outgoing edges to the launch node.
+ std::vector<const Edge*> out_edges(node->out_edges().begin(),
+ node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ Node* dst = edge->dst();
+ int src_output = edge->src_output();
+ int dst_input = edge->dst_input();
+ graph->RemoveEdge(edge);
+
+ if (edge->IsControlEdge()) {
+ graph->AddControlEdge(launch_node, dst);
+ } else {
+ graph->AddEdge(launch_node, src_output, dst, dst_input);
+ }
+ }
+ graph->RemoveNode(node);
+
+ return Status::OK();
+}
+
+Status BuildXlaLaunchOpsPass::Run(const GraphOptimizationPassOptions& options) {
+ Graph* graph = options.graph->get();
+
+ for (Node* n : graph->nodes()) {
+ // In all cases, only try to compile computational nodes.
+ if (!n->IsOp() || n->IsSend() || n->IsRecv() || n->IsControlFlow()) {
+ continue;
+ }
+
+ // Only compile nodes that are marked for compilation by the
+ // compilation-marking pass (via 'attr_name').
+ if (IsXlaCompiledKernel(*n)) {
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaLaunch(graph, n));
+ }
+ }
+ return Status::OK();
+}
+
+namespace {
+
+// Given a NodeDef 'ndef' and the function library runtime 'flr', if
+// 'ndef' is a call to a compilable function defined in 'flr', returns OK
+// and fills in 'kernel' with an XlaLaunchOp kernel that computes the
+// node. Otherwise, returns a non-OK status.
+//
+// This routine is here so that FunctionLibraryRuntime can jit a
+// specific function call as requested.
+Status CreateXlaLaunchOp(FunctionLibraryRuntime* flr, const NodeDef& ndef,
+ std::unique_ptr<OpKernel>* kernel) {
+ bool xla_compile = false;
+ if (!flr->GetFunctionLibraryDefinition()
+ ->GetAttr(ndef, kXlaCompileAttr, &xla_compile)
+ .ok() ||
+ !xla_compile) {
+ // Not marked as _XlaCompile=true.
+ return errors::InvalidArgument("No ", kXlaCompileAttr, " for ", ndef.op());
+ }
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterJitKernels();
+ if (!IsCompilable(flr, ndef)) {
+ // ndef is calling a function that XLA can't compile.
+ return errors::InvalidArgument("Not compilable: ", ndef.ShortDebugString());
+ }
+ FunctionLibraryRuntime::Handle handle;
+ // If ndef is not instantiable, e.g., the function does not exist,
+ // simply bail out.
+ TF_RETURN_IF_ERROR(flr->Instantiate(ndef.op(), ndef.attr(), &handle));
+ const FunctionBody* fbody = flr->GetFunctionBody(handle);
+ CHECK(fbody); // Can't be nullptr since we just instantiated it.
+ std::vector<bool> const_args(fbody->arg_types.size());
+ // If we can't analyze the const args, bail out.
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*(fbody->graph), &const_args));
+
+ for (int i = 0; i < const_args.size(); ++i) {
+ if (const_args[i]) {
+ // There is a const arg. Bail out.
+ return errors::InvalidArgument("Const arg: ", i, " in ",
+ DebugString(fbody->fdef));
+ }
+ }
+
+ NodeDef launch_def;
+ launch_def.set_name(ndef.name());
+ launch_def.set_op("_XlaLaunch");
+ launch_def.set_device(flr->device()->name());
+ AddNodeAttr("Tconstants", DataTypeVector{}, &launch_def);
+ AddNodeAttr("Targs", fbody->arg_types, &launch_def);
+ AddNodeAttr("Tresults", fbody->ret_types, &launch_def);
+ NameAttrList func;
+ func.set_name(ndef.op());
+ *(func.mutable_attr()) = ndef.attr();
+ AddNodeAttr("function", func, &launch_def);
+
+ // TODO(b/32387911): Handle the host memory types across function
+ // calls properly. For now, we assume all inputs and outputs are in
+ // device memory.
+ MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
+ MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
+
+ Device* dev = flr->device();
+ Status s;
+ OpKernelConstruction construction(
+ DeviceType(dev->device_type()), dev,
+ dev->GetAllocator(AllocatorAttributes()), &launch_def,
+ &fbody->fdef.signature(), flr, fbody->arg_types, input_memory_types,
+ fbody->ret_types, output_memory_types, flr->graph_def_version(), &s);
+ kernel->reset(new XlaLocalLaunchOp(&construction));
+ return s;
+}
+
+bool RegisterLaunchOpCreator() {
+ RegisterDefaultCustomKernelCreator(CreateXlaLaunchOp);
+ return true;
+}
+
+static bool register_me = RegisterLaunchOpCreator();
+
+} // end namespace
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/build_xla_launch_ops_pass.h b/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
new file mode 100644
index 0000000000..1dfea93f02
--- /dev/null
+++ b/tensorflow/compiler/jit/build_xla_launch_ops_pass.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class BuildXlaLaunchOpsPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_BUILD_XLA_LAUNCH_OPS_PASS_H_
diff --git a/tensorflow/compiler/jit/defs.cc b/tensorflow/compiler/jit/defs.cc
new file mode 100644
index 0000000000..b20ad53ef6
--- /dev/null
+++ b/tensorflow/compiler/jit/defs.cc
@@ -0,0 +1,22 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/defs.h"
+
+namespace tensorflow {
+
+const char* const kXlaCompileAttr = "_XlaCompile";
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/defs.h b/tensorflow/compiler/jit/defs.h
new file mode 100644
index 0000000000..ddc830cb77
--- /dev/null
+++ b/tensorflow/compiler/jit/defs.h
@@ -0,0 +1,29 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Provides definitions needed for use of the TensorFlow XLA
+// device.
+
+#ifndef TENSORFLOW_COMPILER_JIT_DEFS_H_
+#define TENSORFLOW_COMPILER_JIT_DEFS_H_
+
+namespace tensorflow {
+
+// Name of attribute used to tag operators for compilation with XLA
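+// (for example, a marked node carries the attribute entry
+//  attr { key: "_XlaCompile" value { b: true } } in its NodeDef; see
+//  CreateXlaLaunchOp in build_xla_launch_ops_pass.cc).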
+extern const char* const kXlaCompileAttr; // "_XlaCompile"
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_DEFS_H_
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
new file mode 100644
index 0000000000..72c440abe8
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -0,0 +1,660 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+
+#include <functional>
+#include <numeric>
+
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+const char* const kXlaCompiledKernelAttr = "_XlaCompiledKernel";
+const char* const kXlaNumConstantArgsAttr = "_XlaNumConstantArgs";
+
+namespace {
+
+// A node/slot pair.
+// TODO(phawkins): is there a common definition of this?
+struct NodeSlot {
+ NodeSlot() : node(nullptr), slot(-1) {}
+ NodeSlot(const Node* node, int slot) : node(node), slot(slot) {}
+
+ const Node* node;
+ int slot;
+
+ bool operator==(const NodeSlot& other) const {
+ return node == other.node && slot == other.slot;
+ }
+
+ struct Hasher {
+ uint64 operator()(NodeSlot const& s) const {
+ return Hash64Combine(std::hash<const Node*>()(s.node),
+ std::hash<int>()(s.slot));
+ }
+ };
+
+ struct PairHasher {
+ uint64 operator()(std::pair<NodeSlot, NodeSlot> const& s) const {
+ return Hash64Combine(Hasher()(s.first), Hasher()(s.second));
+ }
+ };
+};
+
+class Encapsulator {
+ public:
+ Encapsulator(string group_attribute, Graph const* graph_in)
+ : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {}
+
+ // Find subgraphs marked with 'group_attribute', and build a new
+ // subgraph, one for each value of 'group_attribute'.
+ Status SplitIntoSubgraphs();
+
+ // Build a FunctionDef for each subgraph, and add it to 'library'. The values of
+ // the 'group_attribute' annotations become the function names.
+ // If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
+ // function conversion.
+ Status BuildFunctionDefs(const RewriteSubgraphFn& rewrite_subgraph_fn,
+ FunctionLibraryDefinition* library);
+
+ // Write a copy of the input graph to 'graph_out', where the subgraphs are
+ // replaced with calls to the new functions.
+ Status BuildOutputGraph(bool parallel_checking, Graph* graph_out);
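+
+ // Typical usage (sketch; 'rewrite_fn', 'library' and 'graph_out' are
+ // objects provided by the caller):
+ //   Encapsulator encapsulator(group_attribute, graph_in);
+ //   TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs());
+ //   TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs(rewrite_fn, library));
+ //   TF_RETURN_IF_ERROR(encapsulator.BuildOutputGraph(parallel_checking,
+ //                                                    graph_out));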
+
+ private:
+ // Returns the key attribute associated with a node. Returns the empty string
+ // if no key attribute is found.
+ string GetFunctionNameAttr(const Node* node) const;
+
+ // A subgraph of the input, all marked with a common 'group_attribute'
+ // value.
+ struct Subgraph {
+ // The subgraph extracted from the input graph, suitable for being turned
+ // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are
+ // returned by _Retval nodes.
+ std::unique_ptr<Graph> graph;
+
+ // Which device are these nodes on? Used both to check that all nodes
+ // are assigned to the same device, and to assign a device to the call node.
+ string device;
+
+ // NodeDef for the function call node.
+ NodeDef call_node_def;
+
+ // Function call node(s) in the output graph. Not owned.
+ // If parallel_checking is enabled, 'call_node_inputs' is the function call
+ // node to which inputs should be fed, and 'call_node_outputs' is the
+ // parallel check op from which outputs should be read. If parallel checking
+ // is disabled, both point to the function call node.
+ Node* call_node_inputs;
+ Node* call_node_outputs;
+
+ // Maps from source (producer node/slot) and destination
+ // (consumer node/slot) tensors in the input graph to _Arg numbers in
+ // the subgraph. The source map is one-to-one, whereas the dest map may be
+ // many-to-one.
+ std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_src;
+ std::unordered_map<NodeSlot, int, NodeSlot::Hasher> args_by_dst;
+
+ // The _Arg nodes in the subgraph, in order by argument number.
+ std::vector<Node*> args;
+
+ // Map from source tensor in the input graph to result #.
+ std::unordered_map<NodeSlot, int, NodeSlot::Hasher> results;
+ };
+
+ // Builds a ParallelCheck op that compares the output of the original subgraph
+ // with the encapsulated subgraph.
+ Status BuildParallelCheckOp(
+ const std::unordered_map<const Node*, Node*>& node_images,
+ const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op);
+
+ const string group_attribute_;
+ const Graph* graph_in_;
+
+ std::unordered_map<string, Subgraph> subgraphs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator);
+};
+
+// TODO(phawkins): add a canonical copy of these operator names and refactor
+// everything to use it.
+static const char* const kArgOp = "_Arg";
+static const char* const kRetValOp = "_Retval";
+
+// Returns the function name attached to 'node', or the empty string if there is
+// none.
+string Encapsulator::GetFunctionNameAttr(Node const* node) const {
+ string attr;
+ if (!GetNodeAttr(node->def(), group_attribute_, &attr).ok()) {
+ attr.clear();
+ }
+ return attr;
+}
+
+Status Encapsulator::SplitIntoSubgraphs() {
+ Status s;
+
+ // Map from input graph nodes to subgraph nodes.
+ std::unordered_map<Node*, Node*> node_images;
+
+ // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes.
+ for (Node* node : graph_in_->nodes()) {
+ if (node->IsSource() || node->IsSink()) continue;
+ string func_id = GetFunctionNameAttr(node);
+ if (func_id.empty()) continue;
+
+ Subgraph& subgraph = subgraphs_[func_id];
+ if (!subgraph.graph) {
+ subgraph.graph.reset(new Graph(graph_in_->op_registry()));
+ subgraph.graph->set_versions(graph_in_->versions());
+ }
+
+ Node* image = subgraph.graph->CopyNode(node);
+ image->ClearAttr(group_attribute_);
+ node_images[node] = image;
+
+ // Check the device matches any existing device.
+ string device = node->assigned_device_name().empty()
+ ? node->def().device()
+ : node->assigned_device_name();
+
+ if (subgraph.device.empty()) {
+ subgraph.device = device;
+ } else if (subgraph.device != device) {
+ s.Update(errors::InvalidArgument(
+ "Mismatched devices for nodes to be grouped by Encapsulator"));
+ }
+ }
+
+ // Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for
+ // data edges that cross subgraph boundaries.
+ for (const Edge* edge : graph_in_->edges()) {
+ string src_func_id = GetFunctionNameAttr(edge->src());
+ string dst_func_id = GetFunctionNameAttr(edge->dst());
+ Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr);
+ Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr);
+
+ // Copy edges that are local to a subgraph.
+ if (!src_func_id.empty() && src_func_id == dst_func_id) {
+ Graph* g = subgraphs_[src_func_id].graph.get();
+ if (edge->IsControlEdge()) {
+ g->AddControlEdge(src_image, dst_image);
+ } else {
+ g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input());
+ }
+ continue;
+ }
+
+ // Ignore cross-boundary control edges for right now. We will lift them
+ // onto the enclosing call operators in BuildOutputGraph().
+ if (edge->IsControlEdge()) continue;
+
+ // Add 'src' as an output of its subgraph, if applicable.
+ if (!src_func_id.empty()) {
+ Subgraph& src_subgraph = subgraphs_[src_func_id];
+ int ret_index = src_subgraph.results.size();
+ if (src_subgraph.results
+ .emplace(NodeSlot(edge->src(), edge->src_output()), ret_index)
+ .second) {
+ // Create a new _Retval node
+ DataType dtype = edge->src()->output_type(edge->src_output());
+
+ NodeDef ret_def;
+ ret_def.set_op(kRetValOp);
+ ret_def.set_name(src_subgraph.graph->NewName("output"));
+ AddNodeAttr("T", dtype, &ret_def);
+ AddNodeAttr("index", ret_index, &ret_def);
+ Node* ret = src_subgraph.graph->AddNode(ret_def, &s);
+ if (!s.ok()) return s;
+
+ // Add an edge from 'src' to _Retval.
+ src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0);
+ }
+ }
+
+ // Add 'dst' as an input of its subgraph, if applicable.
+ if (!dst_func_id.empty()) {
+ Subgraph& dst_subgraph = subgraphs_[dst_func_id];
+
+ // Create an _Arg node for this tensor, if none exists yet.
+ std::unordered_map<NodeSlot, int, NodeSlot::Hasher>::iterator iter;
+ bool inserted;
+ std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace(
+ NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size());
+ int arg_index = iter->second;
+ if (inserted) {
+ // This is the first time we have seen this tensor. Create an _Arg node.
+ DataType dtype = edge->dst()->input_type(edge->dst_input());
+
+ NodeDef arg_def;
+ NodeDefBuilder builder(dst_subgraph.graph->NewName("input"), kArgOp);
+ builder.Attr("T", dtype);
+ builder.Attr("index", arg_index);
+ s = builder.Finalize(&arg_def);
+ if (!s.ok()) return s;
+
+ Node* arg = dst_subgraph.graph->AddNode(arg_def, &s);
+ if (!s.ok()) return s;
+
+ dst_subgraph.args.push_back(arg);
+ }
+ // Add an edge from the _Arg node to 'dst' in the subgraph.
+ dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] =
+ arg_index;
+ dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image,
+ edge->dst_input());
+ }
+ }
+
+ for (auto& entry : subgraphs_) {
+ FixupSourceAndSinkEdges(entry.second.graph.get());
+ }
+
+ return s;
+}
+
+Status Encapsulator::BuildFunctionDefs(
+ const RewriteSubgraphFn& rewrite_subgraph_fn,
+ FunctionLibraryDefinition* library) {
+ // For each subgraph, build a FunctionDef.
+ for (auto& subgraph_entry : subgraphs_) {
+ const string& name = subgraph_entry.first;
+ Subgraph& subgraph = subgraph_entry.second;
+
+ subgraph.call_node_def.set_op(name);
+ subgraph.call_node_def.set_name(name);
+ subgraph.call_node_def.set_device(subgraph.device);
+
+ if (rewrite_subgraph_fn) {
+ // Initialize the input and output permutations to the identity.
+ std::vector<int> input_permutation(subgraph.args_by_src.size());
+ std::iota(input_permutation.begin(), input_permutation.end(), 0);
+ std::vector<int> output_permutation(subgraph.results.size());
+ std::iota(output_permutation.begin(), output_permutation.end(), 0);
+
+ TF_RETURN_IF_ERROR(
+ rewrite_subgraph_fn(&subgraph.graph, &input_permutation,
+ &output_permutation, &subgraph.call_node_def));
+
+ // Apply the input/output permutations to the 'args_by_...' and 'results'
+ // mappings in 'subgraph', so when we build edges in BuildOutputGraph() we
+ // connect them to the right input/output positions.
+ if (input_permutation.size() != subgraph.args_by_src.size()) {
+ return errors::InvalidArgument("Input permutation has incorrect size.");
+ }
+ if (output_permutation.size() != subgraph.results.size()) {
+ return errors::InvalidArgument(
+ "Output permutation has incorrect size.");
+ }
+ for (auto& arg : subgraph.args_by_src) {
+ arg.second = input_permutation[arg.second];
+ }
+ for (auto& arg : subgraph.args_by_dst) {
+ arg.second = input_permutation[arg.second];
+ }
+ for (auto& result : subgraph.results) {
+ result.second = output_permutation[result.second];
+ }
+ }
+
+ FunctionDef fdef;
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef));
+
+ if (VLOG_IS_ON(1)) {
+ VLOG(2) << "Build function def " << name;
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph,
+ library);
+ dump_graph::DumpFunctionDefToFile(
+ strings::StrCat("encapsulate_fdef_", name), fdef);
+ }
+
+ TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef));
+ }
+ return Status::OK();
+}
+
+Status Encapsulator::BuildParallelCheckOp(
+ const std::unordered_map<const Node*, Node*>& node_images,
+ const Encapsulator::Subgraph& subgraph, Graph* graph_out,
+ Node** parallel_check_op) {
+ // Build an index mapping output positions to node/slot pairs in the
+ // original graph.
+ std::vector<NodeSlot> results_by_num(subgraph.results.size());
+ for (const auto& entry : subgraph.results) {
+ results_by_num[entry.second] = entry.first;
+ }
+
+ // Build a parallel check NodeDef.
+ int num_results = results_by_num.size();
+ std::vector<DataType> result_dtypes(num_results);
+ std::vector<NodeDefBuilder::NodeOut> expected_outputs(num_results);
+ std::vector<NodeDefBuilder::NodeOut> actual_outputs(num_results);
+ for (int i = 0; i < num_results; ++i) {
+ const NodeSlot& node_slot = results_by_num[i];
+ result_dtypes[i] = node_slot.node->output_type(node_slot.slot);
+ expected_outputs[i] =
+ NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(),
+ node_slot.slot, result_dtypes[i]);
+ actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(),
+ i, result_dtypes[i]);
+ }
+ // Assign the parallel check op to a CPU on the same task as the cluster it is
+ // checking.
+ string device, dummy;
+ if (!DeviceNameUtils::SplitDeviceName(
+ subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) {
+ return errors::InvalidArgument("Could not parse device name");
+ }
+ strings::StrAppend(&device, "/cpu:0");
+
+ NodeDef check_def;
+ TF_RETURN_IF_ERROR(
+ NodeDefBuilder(graph_out->NewName(strings::StrCat(
+ subgraph.call_node_def.name(), "_parallel_check")),
+ "ParallelCheck")
+ .Device(device)
+ .Attr("T", result_dtypes)
+ .Input(expected_outputs)
+ .Input(actual_outputs)
+ .Finalize(&check_def));
+
+ Status s;
+ Node* check_op = graph_out->AddNode(check_def, &s);
+ if (!s.ok()) return s;
+ check_op->set_assigned_device_name(device);
+
+ // TODO(phawkins): it seems redundant to call AddEdge as well as
+ // pass Inputs to the NodeDefBuilder, but I have been unable to find a
+ // way to avoid it.
+ for (int i = 0; i < num_results; ++i) {
+ const NodeSlot& node_slot = results_by_num[i];
+ graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op,
+ i);
+ graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i);
+ }
+
+ *parallel_check_op = check_op;
+ return Status::OK();
+}
+
+Status Encapsulator::BuildOutputGraph(bool parallel_checking,
+ Graph* graph_out) {
+ Status s;
+
+ // Map from nodes in the input graph to nodes in the output graph.
+ std::unordered_map<const Node*, Node*> node_images;
+
+ // Copy all unmarked nodes to the output graph.
+ for (Node* node : graph_in_->nodes()) {
+ if (node->IsSource() || node->IsSink()) continue;
+ string func_id = GetFunctionNameAttr(node);
+
+    // Don't copy nodes that are going to be encapsulated, unless parallel
+    // checking is enabled.
+ if (!func_id.empty() && !parallel_checking) continue;
+
+ Node* image = graph_out->CopyNode(node);
+ node_images[node] = image;
+ }
+ node_images[graph_in_->source_node()] = graph_out->source_node();
+ node_images[graph_in_->sink_node()] = graph_out->sink_node();
+
+ // Add function call nodes for each subgraph.
+ for (auto& subgraph_entry : subgraphs_) {
+ Subgraph& subgraph = subgraph_entry.second;
+
+ subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s);
+ if (!s.ok()) return s;
+
+    // Copy the assigned device over.
+ subgraph.call_node_inputs->set_assigned_device_name(subgraph.device);
+ subgraph.call_node_outputs = subgraph.call_node_inputs;
+
+ if (parallel_checking) {
+ TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out,
+ &subgraph.call_node_outputs));
+ }
+ }
+
+ // Set of edges already added to the output graph, represented as (src, dst)
+ // pairs. We use the set to deduplicate edges; multiple edges in the input
+ // graph may map to one edge in the output graph.
+ std::unordered_set<std::pair<NodeSlot, NodeSlot>, NodeSlot::PairHasher>
+ edges_added;
+
+ // Add edges to the graph_out graph.
+ for (const Edge* edge : graph_in_->edges()) {
+ string src_func_id = GetFunctionNameAttr(edge->src());
+ string dst_func_id = GetFunctionNameAttr(edge->dst());
+
+ // Ignore edges that are strictly contained within one subgraph, unless
+ // we are constructing parallel check graphs.
+ if (!src_func_id.empty() && src_func_id == dst_func_id) {
+ if (parallel_checking) {
+ Node* src_image = node_images.at(edge->src());
+ Node* dst_image = node_images.at(edge->dst());
+ if (edge->IsControlEdge()) {
+ graph_out->AddControlEdge(src_image, dst_image);
+ } else {
+ graph_out->AddEdge(src_image, edge->src_output(), dst_image,
+ edge->dst_input());
+ }
+ }
+ continue;
+ }
+
+ // We have an edge that crosses a cluster boundary.
+ Node* src_image = src_func_id.empty()
+ ? node_images.at(edge->src())
+ : subgraphs_.at(src_func_id).call_node_outputs;
+ Node* dst_image = dst_func_id.empty()
+ ? node_images.at(edge->dst())
+ : subgraphs_.at(dst_func_id).call_node_inputs;
+
+ // Copy control edges. Lift control edges onto the enclosing call operator.
+ if (edge->IsControlEdge()) {
+ // Add the control edge, if we have not already added it.
+ if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1))
+ .second) {
+ graph_out->AddControlEdge(src_image, dst_image);
+ }
+
+ // If parallel checking is enabled, also add a control edge to the
+ // corresponding parallel check op.
+ if (parallel_checking) {
+ graph_out->AddControlEdge(src_image, node_images.at(edge->dst()));
+ }
+ continue;
+ }
+
+ int src_output = edge->src_output();
+ if (!src_func_id.empty()) {
+ // 'src' is in a subgraph. Use the corresponding call output instead.
+ const Subgraph& src_subgraph = subgraphs_.at(src_func_id);
+ src_output =
+ src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output()));
+ }
+
+ int dst_input = edge->dst_input();
+
+ if (!dst_func_id.empty()) {
+ // 'dst' is in a subgraph. Use the corresponding call input instead.
+ const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id);
+ dst_input =
+ dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input()));
+
+ // If we are parallel checking, also feed the tensor as an input to the
+ // corresponding parallel check subgraph.
+ if (parallel_checking) {
+ graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()),
+ edge->dst_input());
+ }
+ }
+ // Add the edge, if we have not already added it.
+ if (edges_added
+ .emplace(NodeSlot(src_image, src_output),
+ NodeSlot(dst_image, dst_input))
+ .second) {
+ graph_out->AddEdge(src_image, src_output, dst_image, dst_input);
+ }
+ }
+
+ return s;
+}
+
+} // anonymous namespace
+
+Status EncapsulateSubgraphsInFunctions(
+ string group_attribute, const Graph& graph_in,
+ const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
+ std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library) {
+ Status s;
+
+ Encapsulator encapsulator(std::move(group_attribute), &graph_in);
+ s = encapsulator.SplitIntoSubgraphs();
+ if (!s.ok()) return s;
+
+ s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, library);
+ if (!s.ok()) return s;
+
+ std::unique_ptr<Graph> out(new Graph(library));
+ out->set_versions(graph_in.versions());
+ s = encapsulator.BuildOutputGraph(parallel_checking, out.get());
+ if (!s.ok()) return s;
+
+ *graph_out = std::move(out);
+ return s;
+}
+
+// Renumber the indices of _Arg nodes in a graph according to 'permutation',
+// which maps old indices to new indices.
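+// For example (illustrative values only): with permutation {2, 0, 1}, an _Arg
+// node whose "index" attribute is 0 is rewritten to index 2, index 1 becomes
+// index 0, and index 2 becomes index 1.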
+static Status RenumberArguments(Graph* graph,
+ const std::vector<int>& permutation) {
+ for (Node* n : graph->nodes()) {
+ if (n->type_string() == kArgOp) {
+ int index;
+ TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), "index", &index));
+ if (index < 0 || index >= permutation.size()) {
+ return errors::InvalidArgument("Invalid argument number");
+ }
+ n->AddAttr("index", permutation[index]);
+ }
+ }
+ return Status::OK();
+}
+
+Status EncapsulateSubgraphsPass::Run(
+ const GraphOptimizationPassOptions& options) {
+ VLOG(1) << "EncapsulateSubgraphsPass::Run";
+ legacy_flags::EncapsulateSubgraphsPassFlags* flags =
+ legacy_flags::GetEncapsulateSubgraphsPassFlags();
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("before_encapsulate_subgraphs", **options.graph,
+ options.flib_def);
+ }
+
+ std::unique_ptr<Graph> graph_out;
+ FunctionLibraryDefinition* const library = options.flib_def;
+
+ OptimizerOptions opts;
+ std::unique_ptr<FunctionLibraryRuntime> flr(
+ NewFunctionLibraryRuntime(nullptr, options.session_options->env, nullptr,
+ TF_GRAPH_DEF_VERSION, library, opts));
+
+ auto rewrite_subgraph = [&flr](
+ std::unique_ptr<Graph>* subgraph, std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation, NodeDef* node) {
+ // Optimize the subgraph.
+ Graph* g = subgraph->release();
+ OptimizeGraph(flr.get(), &g);
+ subgraph->reset(g);
+
+ std::vector<bool> const_args(input_permutation->size());
+ TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*g, &const_args));
+
+ // Compute a permutation of the arguments such that the constant arguments
+ // are first.
+ const int num_consts =
+ std::count(const_args.begin(), const_args.end(), true);
+ int const_pos = 0;
+ int arg_pos = num_consts;
+ for (int i = 0; i < const_args.size(); ++i) {
+ if (const_args[i]) {
+ (*input_permutation)[i] = const_pos;
+ ++const_pos;
+ } else {
+ (*input_permutation)[i] = arg_pos;
+ ++arg_pos;
+ }
+ }
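+    // Worked example (illustrative values only): if const_args is
+    // {false, true, false}, then num_consts is 1 and input_permutation becomes
+    // {1, 0, 2}: the single constant argument moves to position 0 and the
+    // non-constant arguments keep their relative order.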
+
+ // Renumber argument nodes in the graph.
+ TF_RETURN_IF_ERROR(RenumberArguments(subgraph->get(), *input_permutation));
+
+ // TODO(phawkins): add a forward is-constant analysis, similarly split
+ // outputs into host-memory constants and device-memory non-constants.
+
+ AddNodeAttr(kXlaCompiledKernelAttr, true, node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_consts, node);
+ return Status::OK();
+ };
+
+ TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
+ kXlaClusterAttr, **options.graph, rewrite_subgraph,
+ flags->tf_xla_parallel_checking, &graph_out, library));
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out,
+ options.flib_def);
+ }
+
+ *options.graph = std::move(graph_out);
+ return Status::OK();
+}
+
+bool IsXlaCompiledKernel(const Node& node) {
+ bool is_compiled = false;
+ bool has_compilation_attr =
+ GetNodeAttr(node.def(), kXlaCompiledKernelAttr, &is_compiled).ok() &&
+ is_compiled;
+ return has_compilation_attr ? is_compiled : false;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
new file mode 100644
index 0000000000..ffd39f0b77
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h
@@ -0,0 +1,86 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An optimization pass that groups nodes marked with a common
+// kXlaClusterAttr into functions, and replaces the original nodes by
+// calls. The calls are annotated with kXlaCompiledKernelAttr.
+
+#ifndef TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// A rewriting function to apply to each subgraph during encapsulation.
+// 'graph' is the subgraph. The rewriting may renumber the inputs and outputs;
+// 'input_permutation' is a mapping from old argument numbers to new argument
+// numbers, whereas 'output_permutation' is the same for outputs. Both
+// 'input_permutation' and 'output_permutation' are initialized to the identity
+// permutation. 'node_def' is the NodeDef for the call to the function under
+// construction, provided to allow additional attributes to be set.
+typedef std::function<Status(
+ std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
+ std::vector<int>* output_permutation, NodeDef* node_def)>
+ RewriteSubgraphFn;
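+
+// A minimal sketch of a custom RewriteSubgraphFn that leaves the subgraph
+// untouched and only annotates the call node ("_my_marker" is a hypothetical
+// attribute name used purely for illustration):
+//
+//   RewriteSubgraphFn mark_only =
+//       [](std::unique_ptr<Graph>* graph, std::vector<int>* input_permutation,
+//          std::vector<int>* output_permutation, NodeDef* node_def) {
+//         AddNodeAttr("_my_marker", true, node_def);
+//         return Status::OK();
+//       };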
+
+// Transformation that finds subgraphs whose nodes are marked with
+// 'group_attribute', splits those subgraphs into functions, and replaces
+// the originals with function calls.
+//
+// 'group_attribute' must be a string-valued attribute that names the new
+// functions to introduce.
+//
+// If 'rewrite_subgraph_fn' is set, it is applied to each subgraph before
+// function conversion.
+//
+// If 'parallel_checking' is true, the unencapsulated operators are added to
+// the output graph, together with a "ParallelCheck" operator that verifies
+// that the original and encapsulated subgraphs produce similar results.
+//
+// TODO(phawkins): currently, some information in control edges
+// is not preserved. Suppose you have A and B in the main
+// graph, C and D in a subgraph. B and C have control deps from A, and D has a
+// control dep from B. Originally D must run after C; post-transformation this
+// dependency is lost.
+Status EncapsulateSubgraphsInFunctions(
+ string group_attribute, const Graph& graph_in,
+ const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking,
+ std::unique_ptr<Graph>* graph_out, FunctionLibraryDefinition* library);
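+
+// Example call, as a sketch: assumes a Graph 'graph' and a
+// FunctionLibraryDefinition* 'library' are in scope, and uses "_cluster" as an
+// arbitrary group attribute:
+//
+//   std::unique_ptr<Graph> rewritten;
+//   TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions(
+//       "_cluster", graph, /*rewrite_subgraph_fn=*/{},
+//       /*parallel_checking=*/false, &rewritten, library));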
+
+// The attribute that marks function calls produced by the encapsulate
+// subgraphs pass and that should in turn be compiled via _XlaLaunch operators.
+extern const char* const kXlaCompiledKernelAttr;
+
+// Does `node` have the kXlaCompiledKernelAttr attribute?
+bool IsXlaCompiledKernel(const Node& node);
+
+// Functions produced by the EncapsulateSubgraphs pass have their arguments
+// ordered such that compile-time constant arguments are first in the argument
+// order. The functions are annotated with the following attribute giving the
+// number of constant arguments.
+extern const char* const kXlaNumConstantArgsAttr;
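+
+// Sketch of how a caller might read the attribute back from a call node
+// produced by this pass ('node' is assumed to be such a Node):
+//
+//   int num_constant_args;
+//   TF_RETURN_IF_ERROR(
+//       GetNodeAttr(node.def(), kXlaNumConstantArgsAttr, &num_constant_args));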
+
+class EncapsulateSubgraphsPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_ENCAPSULATE_SUBGRAPHS_PASS_H_
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
new file mode 100644
index 0000000000..c85882e0d7
--- /dev/null
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -0,0 +1,397 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/equal_graph_def.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
+ string* diff) {
+ // TODO(phawkins) use a more sophisticated equality test.
+ if (a.DebugString() != b.DebugString()) {
+ if (diff) {
+ *diff = strings::StrCat("Definition mismatch for function ",
+ a.signature().name(), ", expected:\n",
+ a.DebugString());
+ }
+ return false;
+ }
+ return true;
+}
+
+bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected,
+ const FunctionDefLibrary& actual, string* diff) {
+ std::unordered_map<string, const FunctionDef*> actual_index;
+ for (const FunctionDef& function : actual.function()) {
+ actual_index[function.signature().name()] = &function;
+ }
+
+ for (const FunctionDef& expected_function : expected.function()) {
+ auto it = actual_index.find(expected_function.signature().name());
+ if (it == actual_index.end()) {
+ if (diff) {
+ *diff = strings::StrCat("Did not find expected function '",
+ expected_function.signature().name(), "'");
+ }
+ return false;
+ }
+ if (!EqualFunctionDef(expected_function, *it->second, diff)) return false;
+ actual_index.erase(it);
+ }
+
+ if (!actual_index.empty()) {
+ if (diff != nullptr) {
+ *diff = strings::StrCat("Found unexpected function '",
+ actual_index.begin()->second->signature().name(),
+ "'");
+ }
+ return false;
+ }
+
+ return true;
+}
+
+#define TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(expected, actual) \
+ do { \
+ string diff; \
+ EXPECT_TRUE(EqualFunctionDefLibrary(actual, expected, &diff)) \
+ << diff << "\nActual: " << actual.DebugString(); \
+ } while (false)
+
+REGISTER_OP("InputTest").Output("o: float");
+
+REGISTER_OP("UnaryTest").Input("a: float").Output("o: float");
+REGISTER_OP("BinaryTest")
+ .Input("a: float")
+ .Input("b: float")
+ .Output("o: float");
+
+REGISTER_OP("AddNLikeTest")
+ .Input("inputs: N * T")
+ .Output("sum: T")
+ .Attr("N: int >= 1")
+ .Attr("T: numbertype")
+ .SetIsCommutative()
+ .SetIsAggregate();
+
+Node* Input(const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("InputTest", opts);
+}
+
+Node* Unary(ops::NodeOut a, const GraphDefBuilder::Options& opts) {
+ return ops::UnaryOp("UnaryTest", a, opts);
+}
+
+Node* Binary(ops::NodeOut a, ops::NodeOut b,
+ const GraphDefBuilder::Options& opts) {
+ return ops::BinaryOp("BinaryTest", a, b, opts);
+}
+
+Node* AddNLike(std::vector<ops::NodeOut> inputs,
+ const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp("AddN"), "AddNLikeTest",
+ opts.op_registry());
+ node_builder.Input(inputs);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Node* ArgOp(int index, DataType type, const GraphDefBuilder::Options& opts) {
+ return ops::SourceOp("_Arg",
+ opts.WithAttr("T", type).WithAttr("index", index));
+}
+
+Node* RetOp(int index, ops::NodeOut a, const GraphDefBuilder::Options& opts) {
+ if (opts.HaveError()) return nullptr;
+ NodeBuilder node_builder(opts.GetNameForOp("Retval"), "_Retval",
+ opts.op_registry());
+ node_builder.Input(a).Attr("index", index);
+ return opts.FinalizeBuilder(&node_builder);
+}
+
+Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) {
+ Status s;
+ // Convert the GraphDef to a Graph
+ std::unique_ptr<FunctionLibraryDefinition> lib_def(
+ new FunctionLibraryDefinition(OpRegistry::Global(), *library));
+ GraphConstructorOptions options;
+ options.allow_internal_ops = true;
+ std::unique_ptr<Graph> graph(new Graph(lib_def.get()));
+ s = ConvertGraphDefToGraph(options, *graphdef, graph.get());
+ if (!s.ok()) return s;
+
+ std::unique_ptr<Graph> graph_out;
+ s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph,
+ /* rewrite_subgraph_fn= */ {},
+ /* parallel_checking= */ false,
+ &graph_out, lib_def.get());
+ if (!s.ok()) return s;
+
+ GraphDef graphdef_out;
+ graph_out->ToGraphDef(&graphdef_out);
+ graphdef->Swap(&graphdef_out);
+
+ *library = lib_def->ToProto();
+ return s;
+}
+
+// If there are no marked nodes, funcification should be a no-op.
+TEST(EncapsulateSubgraphsTest, NoFunctions) {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+
+ Node* a = Input(builder.opts().WithName("A"));
+ Node* b = Input(builder.opts().WithName("B"));
+ Node* c = Unary(a, builder.opts().WithName("C"));
+ Binary(b, c, builder.opts().WithName("D"));
+
+ GraphDef graphdef_in;
+ FunctionDefLibrary library_in;
+ builder.ToGraphDef(&graphdef_in);
+ *library_in.add_function() = test::function::XTimesTwo();
+
+ GraphDef graphdef_out = graphdef_in;
+ FunctionDefLibrary library_out = library_in;
+ TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out));
+
+ TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out);
+ TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out);
+}
+
+// Test with one function to transform.
+TEST(EncapsulateSubgraphsTest, OneFunction) {
+ FunctionDefLibrary library;
+ GraphDef graphdef;
+
+ {
+ *library.add_function() = test::function::XTimesTwo();
+
+ GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
+ Node* a = Input(b1.opts().WithName("A"));
+ Node* b = Input(b1.opts().WithName("B"));
+ // Give nodes 'c' and 'd' names that collide after lowercasing.
+ Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1"));
+ Node* d = Binary(b, c, b1.opts().WithName("c").WithControlInput(c).WithAttr(
+ "_encapsulate", "F1"));
+ Binary(a, d, b1.opts().WithName("E"));
+ b1.ToGraphDef(&graphdef);
+ }
+
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+
+ FunctionDefLibrary library_expected;
+ GraphDef graphdef_expected;
+
+ *library_expected.add_function() = test::function::XTimesTwo();
+ *library_expected.add_function() = FunctionDefHelper::Create(
+ "F1", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
+ {
+ {{"C"}, "UnaryTest", {"input__0"}},
+ {{"c"}, "BinaryTest", {"input__1", "C:o:0"}, {}, {"C"}},
+ },
+ {{"output__2", "c:o:0"}});
+
+ {
+ std::unique_ptr<FunctionLibraryDefinition> lib_def(
+ new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
+ GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
+ Node* a = Input(b2.opts().WithName("A"));
+ Node* b = Input(b2.opts().WithName("B"));
+
+ NodeBuilder node_builder("F1", "F1", lib_def.get());
+ node_builder.Input(a).Input(b);
+ Node* call = b2.opts().FinalizeBuilder(&node_builder);
+
+ Binary(a, call, b2.opts().WithName("E"));
+ b2.ToGraphDef(&graphdef_expected);
+ }
+
+  // Verify the resulting graph and function library.
+ TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
+ TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
+}
+
+// Test with two functions to transform.
+TEST(EncapsulateSubgraphsTest, TwoFunctions) {
+ FunctionDefLibrary library;
+ GraphDef graphdef;
+
+ {
+ *library.add_function() = test::function::XTimesTwo();
+
+ GraphDefBuilder b1(GraphDefBuilder::kFailImmediately);
+ Node* a = Input(b1.opts().WithName("A"));
+ Node* b = Input(b1.opts().WithName("B"));
+ Node* control = Input(b1.opts().WithName("Control"));
+ Node* c =
+ Unary(a, b1.opts().WithName("C").WithControlInput(control).WithAttr(
+ "_encapsulate", "F1"));
+ Node* d =
+ Binary(b, c, b1.opts().WithName("D").WithControlInput(control).WithAttr(
+ "_encapsulate", "F2"));
+ Binary(a, d, b1.opts().WithName("E"));
+ b1.ToGraphDef(&graphdef);
+ }
+
+ TF_EXPECT_OK(Encapsulate(&graphdef, &library));
+
+ FunctionDefLibrary library_expected;
+ GraphDef graphdef_expected;
+
+ *library_expected.add_function() = test::function::XTimesTwo();
+ *library_expected.add_function() = FunctionDefHelper::Create(
+ "F1", {"input__0:float"}, {"output__1:float"}, {},
+ {
+ {{"C"}, "UnaryTest", {"input__0"}},
+ },
+ {{"output__1", "C:o:0"}});
+ *library_expected.add_function() = FunctionDefHelper::Create(
+ "F2", {"input__0:float", "input__1:float"}, {"output__2:float"}, {},
+ {
+ {{"D"}, "BinaryTest", {"input__0", "input__1"}},
+ },
+ {{"output__2", "D:o:0"}});
+
+ {
+ std::unique_ptr<FunctionLibraryDefinition> lib_def(
+ new FunctionLibraryDefinition(OpRegistry::Global(), library_expected));
+ GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get());
+ Node* a = Input(b2.opts().WithName("A"));
+ Node* b = Input(b2.opts().WithName("B"));
+ Node* control = Input(b2.opts().WithName("Control"));
+
+ NodeBuilder nb("F1", "F1", lib_def.get());
+ nb.Input(a).ControlInput(control);
+ Node* call1 = b2.opts().FinalizeBuilder(&nb);
+
+ NodeBuilder nb2("F2", "F2", lib_def.get());
+ nb2.Input(b).Input(call1).ControlInput(control);
+ Node* call2 = b2.opts().FinalizeBuilder(&nb2);
+
+ Binary(a, call2, b2.opts().WithName("E"));
+ b2.ToGraphDef(&graphdef_expected);
+ }
+
+  // Verify the resulting graph and function library.
+ TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef);
+ TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library);
+}
+
+// Returns a vector of node names in 'graph', sorted by name.
+std::vector<string> GraphNodes(const Graph& graph) {
+ std::vector<string> nodes;
+ for (const auto& node : graph.nodes()) {
+ if (!node->IsSource() && !node->IsSink()) {
+ nodes.push_back(node->name());
+ }
+ }
+ std::sort(nodes.begin(), nodes.end());
+ return nodes;
+}
+
+// Returns a sorted vector of (src, dst) edges in 'graph'.
+std::vector<std::pair<string, string>> GraphEdges(const Graph& graph) {
+ std::vector<std::pair<string, string>> edges;
+ for (const Edge* edge : graph.edges()) {
+ if (edge->src()->IsSource() || edge->dst()->IsSink()) continue;
+ edges.emplace_back(
+ strings::StrCat(edge->src()->name(), ":", edge->src_output()),
+ strings::StrCat(edge->dst()->name(), ":", edge->dst_input()));
+ }
+ std::sort(edges.begin(), edges.end());
+ return edges;
+}
+
+TEST(EncapsulateSubgraphsTest, InputDeduplication) {
+ Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
+ "/job:localhost/replica:0/task:0/cpu:0");
+ auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
+ auto add1 = ops::Add(root.WithOpName("add1"), x, x);
+ add1.node()->AddAttr("_cluster", "cluster1");
+ auto add2 = ops::Add(root.WithOpName("add2"), add1, add1);
+ add2.node()->AddAttr("_cluster", "cluster2");
+ auto out = ops::Mul(root.WithOpName("mul"), add1, add2);
+
+ Graph graph_before_encapsulation(OpRegistry::Global());
+ TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
+
+ FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
+ "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
+ /*parallel_checking=*/false, &graph, &library));
+
+ std::vector<string> expected_nodes = {"cluster1", "cluster2", "mul", "x"};
+ EXPECT_EQ(expected_nodes, GraphNodes(*graph));
+
+ std::vector<std::pair<string, string>> expected_edges = {
+ {"cluster1:0", "cluster2:0"},
+ {"cluster1:0", "mul:0"},
+ {"cluster2:0", "mul:1"},
+ {"x:0", "cluster1:0"}};
+ EXPECT_EQ(expected_edges, GraphEdges(*graph));
+}
+
+TEST(EncapsulateSubgraphsTest, ParallelChecking) {
+ Scope root = Scope::NewRootScope().ExitOnError().WithDevice(
+ "/job:localhost/replica:0/task:0/cpu:0");
+ auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT);
+ auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT);
+ auto add1 = ops::Add(root.WithOpName("add1"), x1, x2);
+ add1.node()->AddAttr("_cluster", "cluster1");
+ auto add2 = ops::Add(root.WithOpName("add2"), add1, x2);
+ add2.node()->AddAttr("_cluster", "cluster1");
+ auto out = ops::Mul(root.WithOpName("mul"), x1, add2);
+
+ Graph graph_before_encapsulation(OpRegistry::Global());
+ TF_ASSERT_OK(root.ToGraph(&graph_before_encapsulation));
+
+ FunctionLibraryDefinition library(OpRegistry::Global(), {});
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(EncapsulateSubgraphsInFunctions(
+ "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{},
+ /*parallel_checking=*/true, &graph, &library));
+
+ std::vector<string> expected_nodes = {
+ "add1", "add2", "cluster1", "cluster1_parallel_check/_0",
+ "mul", "x1", "x2"};
+ EXPECT_EQ(expected_nodes, GraphNodes(*graph));
+
+ std::vector<std::pair<string, string>> expected_edges = {
+ {"add1:0", "add2:0"},
+ {"add2:0", "cluster1_parallel_check/_0:0"},
+ {"cluster1:0", "cluster1_parallel_check/_0:1"},
+ {"cluster1_parallel_check/_0:0", "mul:1"},
+ {"x1:0", "add1:0"},
+ {"x1:0", "cluster1:0"},
+ {"x1:0", "mul:0"},
+ {"x2:0", "add1:1"},
+ {"x2:0", "add2:1"},
+ {"x2:0", "cluster1:1"},
+ };
+ EXPECT_EQ(expected_edges, GraphEdges(*graph));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/compiler/jit/graph_to_functiondef.cc
new file mode 100644
index 0000000000..a589a94fd4
--- /dev/null
+++ b/tensorflow/compiler/jit/graph_to_functiondef.cc
@@ -0,0 +1,274 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+// TODO(phawkins) add a canonical copy of these operator names and refactor
+// everything to use it.
+const char* const kArgOp = "_Arg";
+const char* const kRetValOp = "_Retval";
+
+// Class that maintains a one-to-one original node name -> new name mapping.
+// We have to normalize the names used as input and output arguments to
+// match regexp "[a-z][a-z0-9_]*". Once we rename them, we risk creating
+// a name collision with the other node names, so if necessary we add
+// a suffix to make names unique. So if we have an input named "A" and a
+// node in the function body named "a", they will be renamed to "a" and "a_0".
+class NodeNameMapping {
+ public:
+ NodeNameMapping() = default;
+
+ // Normalize the input/output name and then make it unique.
+ string Normalize(const string& name);
+
+ // Make the node name unique.
+ string Uniquify(const string& name);
+
+ // Look up how a node name was previously normalized/uniquified.
+ // Returns empty if name was never seen.
+ string Renormalize(const string& name) const;
+
+ private:
+ string NormalizeHelper(string name) const;
+ string UniquifyHelper(string name);
+
+ std::unordered_set<string> used_names_;
+ std::unordered_map<string, string> name_mapping_;
+};
+
+string NodeNameMapping::NormalizeHelper(string name) const {
+ // Convert letters to lowercase and non-alphanumeric characters to '_'.
+ if (name.empty()) name = "unknown";
+ const int n = name.size();
+ for (int i = 0; i < n; i++) {
+ char c = name[i];
+ if (isalnum(c)) {
+ if (isupper(c)) {
+ name[i] = tolower(c);
+ }
+ } else {
+ name[i] = '_';
+ }
+ }
+ return name;
+}
+
+string NodeNameMapping::UniquifyHelper(string name) {
+ // If the name hasn't been used yet, use it as-is.
+ if (used_names_.insert(name).second) return name;
+ // Add a suffix to name to make it unique.
+ for (int i = 0;; ++i) {
+ const string candidate = strings::StrCat(name, "_", i);
+ if (used_names_.insert(candidate).second) return candidate;
+ }
+}
+
+string NodeNameMapping::Normalize(const string& name) {
+ const string normalized = UniquifyHelper(NormalizeHelper(name));
+ name_mapping_[name] = normalized;
+ return normalized;
+}
+
+string NodeNameMapping::Uniquify(const string& name) {
+ const string uniqued = UniquifyHelper(name);
+ name_mapping_[name] = uniqued;
+ return uniqued;
+}
+
+string NodeNameMapping::Renormalize(const string& name) const {
+ const auto iter = name_mapping_.find(name);
+ if (iter == name_mapping_.end()) return string();
+ return iter->second;
+}
+
+} // anonymous namespace
+
+// Graph to FunctionDef conversion. This code is closely modeled on the Python
+// code in third_party/tensorflow/python/framework/function.py.
+
+Status GraphToFunctionDef(const Graph& graph, const string& name,
+ FunctionDef* fdef) {
+ fdef->mutable_signature()->set_name(name);
+
+ std::unordered_map<string, string> tensor_renaming;
+ std::unordered_map<string, string> return_values;
+ NodeNameMapping node_names;
+
+ for (Node const* node : graph.nodes()) {
+ if (!node->IsOp()) continue;
+
+ if (node->type_string() == kArgOp) {
+ int index;
+ DataType type;
+ GetNodeAttr(node->def(), "T", &type);
+ GetNodeAttr(node->def(), "index", &index);
+ while (fdef->signature().input_arg_size() <= index) {
+ fdef->mutable_signature()->add_input_arg();
+ }
+ OpDef::ArgDef* argdef =
+ fdef->mutable_signature()->mutable_input_arg(index);
+ argdef->set_type(type);
+ const string normalized = node_names.Normalize(node->name());
+ argdef->set_name(normalized);
+ tensor_renaming[strings::StrCat(node->name(), ":0")] = normalized;
+ continue;
+ }
+
+ if (node->type_string() == kRetValOp) {
+ int index;
+ DataType type;
+ GetNodeAttr(node->def(), "T", &type);
+ GetNodeAttr(node->def(), "index", &index);
+ while (fdef->signature().output_arg_size() <= index) {
+ fdef->mutable_signature()->add_output_arg();
+ }
+ OpDef::ArgDef* argdef =
+ fdef->mutable_signature()->mutable_output_arg(index);
+ argdef->set_type(type);
+ const string normalized = node_names.Normalize(node->name());
+ argdef->set_name(normalized);
+ CHECK_EQ(node->in_edges().size(), 1);
+ Edge const* edge = *node->in_edges().begin();
+ return_values[normalized] =
+ strings::StrCat(edge->src()->name(), ":", edge->src_output());
+ continue;
+ }
+
+ NodeDef* node_def = fdef->add_node_def();
+ node_def->CopyFrom(node->def());
+ node_def->set_name(node_names.Uniquify(node->name()));
+ node_def->clear_device();
+
+ // Reset input names based on graph rather than the NodeDef.
+ node_def->clear_input();
+
+ // Edges, indexed by dst_input.
+ std::vector<const Edge*> in_edges;
+ std::vector<const Edge*> control_edges;
+ for (Edge const* edge : node->in_edges()) {
+ if (edge->src()->IsSource()) continue;
+
+ if (edge->IsControlEdge()) {
+ control_edges.push_back(edge);
+ } else {
+ if (in_edges.size() <= edge->dst_input()) {
+ in_edges.resize(edge->dst_input() + 1);
+ }
+ in_edges[edge->dst_input()] = edge;
+ }
+ }
+
+ // Add regular inputs
+ for (int i = 0; i < in_edges.size(); ++i) {
+ const Edge* edge = in_edges[i];
+ if (edge == nullptr) {
+ return errors::InvalidArgument(
+ "Nonconsecutive input edges; missing "
+ "input edge ",
+ i, " for node ", node->name());
+ }
+ node_def->add_input(
+ strings::StrCat(edge->src()->name(), ":", edge->src_output()));
+ }
+
+ // Add control inputs
+ for (const Edge* edge : control_edges) {
+ node_def->add_input(strings::StrCat("^", edge->src()->name()));
+ }
+
+ // Populate tensor_renaming.
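+    // For example, output 0 of an "Add" node named "D", whose OpDef output arg
+    // is called "z", is recorded as tensor_renaming["D:0"] = "D:z:0".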
+ NameRangeMap output_ranges;
+ TF_RETURN_IF_ERROR(NameRangesForNode(node->def(), node->op_def(), nullptr,
+ &output_ranges));
+ for (const auto& output : output_ranges) {
+ for (int i = output.second.first; i < output.second.second; ++i) {
+ const string tensor_name = strings::StrCat(
+ node_def->name(), ":", output.first, ":", i - output.second.first);
+ tensor_renaming[strings::StrCat(node->name(), ":", i)] = tensor_name;
+ }
+ }
+ }
+
+ // Detect missing function inputs.
+ for (int i = 0; i < fdef->signature().input_arg_size(); ++i) {
+ const string& input_name = fdef->signature().input_arg(i).name();
+ if (input_name.empty()) {
+ return errors::InvalidArgument("Missing input ", i, " to function ",
+ name);
+ }
+ }
+
+ // Remap input names. We do this as a second pass to allow the nodes to be in
+ // any order.
+ for (int n_index = 0; n_index < fdef->node_def_size(); ++n_index) {
+ NodeDef* node_def = fdef->mutable_node_def(n_index);
+ for (int i = 0; i < node_def->input_size(); ++i) {
+ if (StringPiece(node_def->input(i)).starts_with("^")) {
+ // Control input
+ const string normalized =
+ node_names.Renormalize(node_def->input(i).substr(1));
+ if (normalized.empty()) {
+ return errors::InvalidArgument(
+ "Could not remap control input ", i, ", '", node_def->input(i),
+ "', of node '", node_def->name(), "' in function ", name);
+ }
+ *node_def->mutable_input(i) = strings::StrCat("^", normalized);
+ } else {
+ const auto iter = tensor_renaming.find(node_def->input(i));
+ if (iter == tensor_renaming.end()) {
+ return errors::InvalidArgument(
+ "Could not remap input ", i, ", '", node_def->input(i),
+ "', of node '", node_def->name(), "' in function ", name);
+ }
+ *node_def->mutable_input(i) = iter->second;
+ }
+ }
+ }
+
+ // Remap return values.
+ for (int r = 0; r < fdef->signature().output_arg_size(); ++r) {
+ const string& ret_name = fdef->signature().output_arg(r).name();
+ if (ret_name.empty()) {
+ return errors::InvalidArgument("Missing output ", r, " to function ",
+ name);
+ }
+ const string& return_value = return_values[ret_name];
+ const auto iter = tensor_renaming.find(return_value);
+ if (iter == tensor_renaming.end()) {
+ return errors::InvalidArgument("Could not remap return value ", r, ", '",
+ ret_name, "', of '", return_value,
+ "' in function ", name);
+ }
+ (*fdef->mutable_ret())[ret_name] = iter->second;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/graph_to_functiondef.h b/tensorflow/compiler/jit/graph_to_functiondef.h
new file mode 100644
index 0000000000..3e1ae7bbbe
--- /dev/null
+++ b/tensorflow/compiler/jit/graph_to_functiondef.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
+#define TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Converts 'graph' to a FunctionDef 'fdef', with name 'name'.
+// Closely modeled on the Python code in
+// third_party/tensorflow/python/framework/function.py
+Status GraphToFunctionDef(const Graph& graph, const string& name,
+ FunctionDef* fdef);
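+
+// Usage sketch ('graph' is assumed to be a Graph already built by the caller;
+// "my_function" is an arbitrary name):
+//
+//   FunctionDef fdef;
+//   TF_RETURN_IF_ERROR(GraphToFunctionDef(graph, "my_function", &fdef));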
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
diff --git a/tensorflow/compiler/jit/graph_to_functiondef_test.cc b/tensorflow/compiler/jit/graph_to_functiondef_test.cc
new file mode 100644
index 0000000000..df45f455a9
--- /dev/null
+++ b/tensorflow/compiler/jit/graph_to_functiondef_test.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/equal_graph_def.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b,
+ string* diff) {
+ // TODO(phawkins) use a more sophisticated equality test.
+ if (a.DebugString() != b.DebugString()) {
+ if (diff) {
+ *diff = strings::StrCat("Definition mismatch for function ",
+ a.signature().name(), ":\n", a.DebugString(),
+ "\n ---- vs. ----\n", b.DebugString());
+ }
+ return false;
+ }
+ return true;
+}
+
+TEST(GraphToFunctionDefTest, Basics) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(root.WithOpName("A"), DT_FLOAT, 0);
+ auto b = ops::_Arg(root.WithOpName("B"), DT_FLOAT, 1);
+ auto c = ops::_Arg(root.WithOpName("C"), DT_FLOAT, 2);
+ auto d = ops::Add(root.WithOpName("D"), a, b);
+ auto e = ops::Add(root.WithOpName("b"), d, c);
+ auto f = ops::Neg(root.WithOpName("h"), e);
+ auto g =
+ ops::AddN(root.WithOpName("G"), std::initializer_list<ops::Output>{e, f});
+ auto h = ops::_Retval(root.WithOpName("H"), g, 0);
+
+ GraphDef graph_def;
+ root.ToGraphDef(&graph_def);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphConstructorOptions options;
+ TF_EXPECT_OK(ConvertGraphDefToGraph(options, graph_def, graph.get()));
+
+ FunctionDef fdef;
+ TF_EXPECT_OK(GraphToFunctionDef(*graph, "test_fn", &fdef));
+
+ FunctionDef fdef_expected = FunctionDefHelper::Create(
+ "test_fn", // function name
+ {"a: float", "b: float", "c: float"}, // inputs
+ {"h_0: float"}, // outputs
+ {}, // attrs
+ {
+ // nodes in the function body
+ {{"D"}, "Add", {"a", "b"}, {{"T", DT_FLOAT}}},
+ {{"b_0"}, "Add", {"D:z:0", "c"}, {{"T", DT_FLOAT}}},
+ {{"h"}, "Neg", {"b_0:z:0"}, {{"T", DT_FLOAT}}},
+ {{"G"}, "AddN", {"b_0:z:0", "h:y:0"}, {{"N", 2}, {"T", DT_FLOAT}}},
+ },
+ {{"h_0", "G:sum:0"}}); // return values
+
+ string diff;
+ bool fdefs_equal = EqualFunctionDef(fdef_expected, fdef, &diff);
+ EXPECT_TRUE(fdefs_equal) << diff;
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/graphcycles/BUILD b/tensorflow/compiler/jit/graphcycles/BUILD
new file mode 100644
index 0000000000..ce634529cb
--- /dev/null
+++ b/tensorflow/compiler/jit/graphcycles/BUILD
@@ -0,0 +1,41 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//tensorflow/compiler/tf2xla:internal",
+ ],
+)
+
+cc_library(
+ name = "graphcycles",
+ srcs = ["graphcycles.cc"],
+ hdrs = ["graphcycles.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "graphcycles_test",
+ srcs = ["graphcycles_test.cc"],
+ deps = [
+ ":graphcycles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.cc b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
new file mode 100644
index 0000000000..87d5de09d1
--- /dev/null
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.cc
@@ -0,0 +1,391 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// GraphCycles provides incremental cycle detection on a dynamic
+// graph using the following algorithm:
+//
+// A dynamic topological sort algorithm for directed acyclic graphs
+// David J. Pearce, Paul H. J. Kelly
+// Journal of Experimental Algorithmics (JEA) JEA Homepage archive
+// Volume 11, 2006, Article No. 1.7
+//
+// Brief summary of the algorithm:
+//
+// (1) Maintain a rank for each node that is consistent
+// with the topological sort of the graph. I.e., path from x to y
+// implies rank[x] < rank[y].
+// (2) When a new edge (x->y) is inserted, do nothing if rank[x] < rank[y].
+// (3) Otherwise: adjust ranks in the neighborhood of x and y.
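+//
+// Worked example (illustrative): with edges a->b and b->c, a valid assignment
+// is rank[a]=0, rank[b]=1, rank[c]=2. Inserting c->a would require
+// rank[c] < rank[a], which no reordering of ranks can satisfy, so the
+// insertion is reported as a cycle.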
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+
+#include <algorithm>
+#include <unordered_set>
+
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+namespace {
+
+typedef std::unordered_set<int32> NodeSet;
+template <typename T>
+struct VecStruct {
+ typedef gtl::InlinedVector<T, 4> type;
+};
+template <typename T>
+using Vec = typename VecStruct<T>::type;
+
+struct Node {
+ Node() : in(4), out(4) {} // Small hashtables for in/out edges
+
+ int32 rank; // rank number assigned by Pearce-Kelly algorithm
+ bool visited; // Temporary marker used by depth-first-search
+ void* data; // User-supplied data
+ NodeSet in; // List of immediate predecessor nodes in graph
+ NodeSet out; // List of immediate successor nodes in graph
+};
+
+} // namespace
+
+struct GraphCycles::Rep {
+ Vec<Node*> nodes_;
+ Vec<int32> free_nodes_; // Indices for unused entries in nodes_
+
+ // Temporary state.
+ Vec<int32> deltaf_; // Results of forward DFS
+ Vec<int32> deltab_; // Results of backward DFS
+ Vec<int32> list_; // All nodes to reprocess
+ Vec<int32> merged_; // Rank values to assign to list_ entries
+ Vec<int32> stack_; // Emulates recursion stack when doing depth first search
+};
+
+GraphCycles::GraphCycles() : rep_(new Rep) {}
+
+GraphCycles::~GraphCycles() {
+ for (int i = 0; i < rep_->nodes_.size(); i++) {
+ delete rep_->nodes_[i];
+ }
+ delete rep_;
+}
+
+bool GraphCycles::CheckInvariants() const {
+ Rep* r = rep_;
+ NodeSet ranks; // Set of ranks seen so far.
+ for (int32 x = 0; x < r->nodes_.size(); x++) {
+ Node* nx = r->nodes_[x];
+ if (nx->visited) {
+ LOG(FATAL) << "Did not clear visited marker on node " << x;
+ }
+ if (!ranks.insert(nx->rank).second) {
+ LOG(FATAL) << "Duplicate occurrence of rank " << nx->rank;
+ }
+ for (auto y : nx->out) {
+ Node* ny = r->nodes_[y];
+ if (nx->rank >= ny->rank) {
+ LOG(FATAL) << "Edge " << x << "->" << y << " has bad rank assignment "
+ << nx->rank << "->" << ny->rank;
+ }
+ }
+ }
+ return true;
+}
+
+int32 GraphCycles::NewNode() {
+ if (rep_->free_nodes_.empty()) {
+ Node* n = new Node;
+ n->visited = false;
+ n->data = NULL;
+ n->rank = rep_->nodes_.size();
+ rep_->nodes_.push_back(n);
+ return n->rank;
+ } else {
+ // Preserve preceding rank since the set of ranks in use must be
+ // a permutation of [0,rep_->nodes_.size()-1].
+ int32 r = rep_->free_nodes_.back();
+ rep_->nodes_[r]->data = NULL;
+ rep_->free_nodes_.pop_back();
+ return r;
+ }
+}
+
+void GraphCycles::RemoveNode(int32 node) {
+ Node* x = rep_->nodes_[node];
+ for (auto y : x->out) {
+ rep_->nodes_[y]->in.erase(node);
+ }
+ for (auto y : x->in) {
+ rep_->nodes_[y]->out.erase(node);
+ }
+ x->in.clear();
+ x->out.clear();
+ rep_->free_nodes_.push_back(node);
+}
+
+void* GraphCycles::GetNodeData(int32 node) const {
+ return rep_->nodes_[node]->data;
+}
+
+void GraphCycles::SetNodeData(int32 node, void* data) {
+ rep_->nodes_[node]->data = data;
+}
+
+bool GraphCycles::HasEdge(int32 x, int32 y) const {
+ return rep_->nodes_[x]->out.find(y) != rep_->nodes_[x]->out.end();
+}
+
+void GraphCycles::RemoveEdge(int32 x, int32 y) {
+ rep_->nodes_[x]->out.erase(y);
+ rep_->nodes_[y]->in.erase(x);
+ // No need to update the rank assignment since a previous valid
+ // rank assignment remains valid after an edge deletion.
+}
+
+static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound);
+static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound);
+static void Reorder(GraphCycles::Rep* r);
+static void Sort(const Vec<Node*>&, Vec<int32>* delta);
+static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst);
+static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes);
+
+bool GraphCycles::InsertEdge(int32 x, int32 y) {
+ if (x == y) return false;
+ Rep* r = rep_;
+ Node* nx = r->nodes_[x];
+ if (!nx->out.insert(y).second) {
+ // Edge already exists.
+ return true;
+ }
+
+ Node* ny = r->nodes_[y];
+ ny->in.insert(x);
+
+ if (nx->rank <= ny->rank) {
+ // New edge is consistent with existing rank assignment.
+ return true;
+ }
+
+ // Current rank assignments are incompatible with the new edge. Recompute.
+ // We only need to consider nodes that fall in the range [ny->rank,nx->rank].
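+  // Illustrative case: if rank[x] = 5 and rank[y] = 2, only nodes whose ranks
+  // currently lie in [2, 5] can need new ranks; every other node already
+  // satisfies the invariant with respect to the new edge x->y.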
+ if (!ForwardDFS(r, y, nx->rank)) {
+ // Found a cycle. Undo the insertion and tell caller.
+ nx->out.erase(y);
+ ny->in.erase(x);
+ // Since we do not call Reorder() on this path, clear any visited
+ // markers left by ForwardDFS.
+ ClearVisitedBits(r, r->deltaf_);
+ return false;
+ }
+ BackwardDFS(r, x, ny->rank);
+ Reorder(r);
+ return true;
+}
+
+static bool ForwardDFS(GraphCycles::Rep* r, int32 n, int32 upper_bound) {
+ // Avoid recursion since stack space might be limited.
+ // We instead keep a stack of nodes to visit.
+ r->deltaf_.clear();
+ r->stack_.clear();
+ r->stack_.push_back(n);
+ while (!r->stack_.empty()) {
+ n = r->stack_.back();
+ r->stack_.pop_back();
+ Node* nn = r->nodes_[n];
+ if (nn->visited) continue;
+
+ nn->visited = true;
+ r->deltaf_.push_back(n);
+
+ for (auto w : nn->out) {
+ Node* nw = r->nodes_[w];
+ if (nw->rank == upper_bound) {
+ return false; // Cycle
+ }
+ if (!nw->visited && nw->rank < upper_bound) {
+ r->stack_.push_back(w);
+ }
+ }
+ }
+ return true;
+}
+
+static void BackwardDFS(GraphCycles::Rep* r, int32 n, int32 lower_bound) {
+ r->deltab_.clear();
+ r->stack_.clear();
+ r->stack_.push_back(n);
+ while (!r->stack_.empty()) {
+ n = r->stack_.back();
+ r->stack_.pop_back();
+ Node* nn = r->nodes_[n];
+ if (nn->visited) continue;
+
+ nn->visited = true;
+ r->deltab_.push_back(n);
+
+ for (auto w : nn->in) {
+ Node* nw = r->nodes_[w];
+ if (!nw->visited && lower_bound < nw->rank) {
+ r->stack_.push_back(w);
+ }
+ }
+ }
+}
+
+static void Reorder(GraphCycles::Rep* r) {
+ Sort(r->nodes_, &r->deltab_);
+ Sort(r->nodes_, &r->deltaf_);
+
+ // Adds contents of delta lists to list_ (backwards deltas first).
+ r->list_.clear();
+ MoveToList(r, &r->deltab_, &r->list_);
+ MoveToList(r, &r->deltaf_, &r->list_);
+
+ // Produce sorted list of all ranks that will be reassigned.
+ r->merged_.resize(r->deltab_.size() + r->deltaf_.size());
+ std::merge(r->deltab_.begin(), r->deltab_.end(), r->deltaf_.begin(),
+ r->deltaf_.end(), r->merged_.begin());
+
+ // Assign the ranks in order to the collected list.
+ for (int32 i = 0; i < r->list_.size(); i++) {
+ r->nodes_[r->list_[i]]->rank = r->merged_[i];
+ }
+}
+
+static void Sort(const Vec<Node*>& nodes, Vec<int32>* delta) {
+ struct ByRank {
+ const Vec<Node*>* nodes;
+ bool operator()(int32 a, int32 b) const {
+ return (*nodes)[a]->rank < (*nodes)[b]->rank;
+ }
+ };
+ ByRank cmp;
+ cmp.nodes = &nodes;
+ std::sort(delta->begin(), delta->end(), cmp);
+}
+
+static void MoveToList(GraphCycles::Rep* r, Vec<int32>* src, Vec<int32>* dst) {
+ for (int32 i = 0; i < src->size(); i++) {
+ int32 w = (*src)[i];
+ (*src)[i] = r->nodes_[w]->rank; // Replace src entry with its rank
+ r->nodes_[w]->visited = false; // Prepare for future DFS calls
+ dst->push_back(w);
+ }
+}
+
+static void ClearVisitedBits(GraphCycles::Rep* r, const Vec<int32>& nodes) {
+ for (int32 i = 0; i < nodes.size(); i++) {
+ r->nodes_[nodes[i]]->visited = false;
+ }
+}
+
+int GraphCycles::FindPath(int32 x, int32 y, int max_path_len,
+ int32 path[]) const {
+ // Forward depth first search starting at x until we hit y.
+ // As we descend into a node, we push it onto the path.
+ // As we leave a node, we remove it from the path.
+ int path_len = 0;
+
+ Rep* r = rep_;
+ NodeSet seen;
+ r->stack_.clear();
+ r->stack_.push_back(x);
+ while (!r->stack_.empty()) {
+ int32 n = r->stack_.back();
+ r->stack_.pop_back();
+ if (n < 0) {
+ // Marker to indicate that we are leaving a node
+ path_len--;
+ continue;
+ }
+
+ if (path_len < max_path_len) {
+ path[path_len] = n;
+ }
+ path_len++;
+ r->stack_.push_back(-1); // Will remove tentative path entry
+
+ if (n == y) {
+ return path_len;
+ }
+
+ for (auto w : r->nodes_[n]->out) {
+ if (seen.insert(w).second) {
+ r->stack_.push_back(w);
+ }
+ }
+ }
+
+ return 0;
+}
+
+bool GraphCycles::IsReachable(int32 x, int32 y) const {
+ return FindPath(x, y, 0, NULL) > 0;
+}
+
+bool GraphCycles::IsReachableNonConst(int32 x, int32 y) {
+ if (x == y) return true;
+ Rep* r = rep_;
+ Node* nx = r->nodes_[x];
+ Node* ny = r->nodes_[y];
+
+ if (nx->rank >= ny->rank) {
+ // x cannot reach y since it is after it in the topological ordering
+ return false;
+ }
+
+ // See if x can reach y using a DFS search that is limited to y's rank
+ bool reachable = !ForwardDFS(r, x, ny->rank);
+
+ // Clear any visited markers left by ForwardDFS.
+ ClearVisitedBits(r, r->deltaf_);
+ return reachable;
+}
+
+bool GraphCycles::ContractEdge(int32 a, int32 b) {
+ CHECK(HasEdge(a, b));
+ RemoveEdge(a, b);
+
+ if (IsReachableNonConst(a, b)) {
+ // Restore the graph to its original state.
+ InsertEdge(a, b);
+ return false;
+ }
+
+ Node* nb = rep_->nodes_[b];
+ std::unordered_set<int32> out = std::move(nb->out);
+ std::unordered_set<int32> in = std::move(nb->in);
+ for (auto y : out) {
+ rep_->nodes_[y]->in.erase(b);
+ }
+ for (auto y : in) {
+ rep_->nodes_[y]->out.erase(b);
+ }
+ rep_->free_nodes_.push_back(b);
+
+ for (auto y : out) {
+ InsertEdge(a, y);
+ }
+ for (auto y : in) {
+ InsertEdge(y, a);
+ }
+ return true;
+}
+
+std::unordered_set<int32> GraphCycles::Successors(int32 node) {
+ return rep_->nodes_[node]->out;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles.h b/tensorflow/compiler/jit/graphcycles/graphcycles.h
new file mode 100644
index 0000000000..d11d6e27b1
--- /dev/null
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles.h
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
+#define TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
+
+// GraphCycles detects the introduction of a cycle into a directed
+// graph that is being built up incrementally.
+//
+// Nodes are identified by small integers. It is not possible to
+// record multiple edges with the same (source, destination) pair;
+// requests to add an edge where one already exists are silently
+// ignored.
+//
+// It is also not possible to introduce a cycle; an attempt to insert
+// an edge that would introduce a cycle fails and returns false.
+//
+// GraphCycles uses no internal locking; calls into it should be
+// serialized externally.
+
+// Performance considerations:
+// Works well on sparse graphs, poorly on dense graphs.
+// Extra information is maintained incrementally to detect cycles quickly.
+// InsertEdge() is very fast when the edge already exists, and reasonably fast
+// otherwise.
+// FindPath() is linear in the size of the graph.
+// The current implementation uses O(|V|+|E|) space.
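+//
+// A minimal usage sketch (for illustration only; the node variables below are
+// placeholders, not part of this header):
+//
+//   GraphCycles graph;
+//   int32 a = graph.NewNode();
+//   int32 b = graph.NewNode();
+//   CHECK(graph.InsertEdge(a, b));   // a -> b succeeds; the graph is acyclic
+//   CHECK(!graph.InsertEdge(b, a));  // rejected: b -> a would close a cycle
+//   CHECK(graph.HasEdge(a, b));      // the earlier edge is left untouched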
+
+#include <unordered_set>
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// NOTE!!!
+// For now a copy of this is forked to net/plaque. If you
+// find a bug or add a feature, please inform the owners of the
+// net/plaque copy in case it should be integrated.
+// NOTE!!!
+class GraphCycles {
+ public:
+ GraphCycles();
+ ~GraphCycles();
+
+ // Allocate an unused node id and return it.
+ // The new node has a null pointer for its node data.
+ // All node identifiers passed to other routines in this interface
+ // must have been allocated by NewNode() and not yet deallocated
+ // by RemoveNode().
+ int32 NewNode();
+
+ // Remove "node" from the graph, deleting all edges to and from it.
+  // After this call the identifier "node" may no longer be used
+ // as an argument to any routine until it has been reallocated with
+ // NewNode().
+ void RemoveNode(int32 node);
+
+ // Attempt to insert an edge from source_node to dest_node. If the
+ // edge would introduce a cycle, return false without making any
+ // changes. Otherwise add the edge and return true.
+ bool InsertEdge(int32 source_node, int32 dest_node);
+
+ // Remove any edge that exists from source_node to dest_node.
+ void RemoveEdge(int32 source_node, int32 dest_node);
+
+ // Return whether there is an edge directly from source_node to dest_node.
+ bool HasEdge(int32 source_node, int32 dest_node) const;
+
+ // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. 'b' is
+ // removed from the graph, and edges to/from 'b' are replaced with edges
+ // to/from 'a'. If contracting the edge would create a cycle, does nothing
+ // and returns false.
+ bool ContractEdge(int32 a, int32 b);
+
+ // Return whether dest_node is reachable from source_node
+ // by following edges.
+ bool IsReachable(int32 source_node, int32 dest_node) const;
+
+ // A faster non-thread-safe version of IsReachable.
+ bool IsReachableNonConst(int32 source_node, int32 dest_node);
+
+ // Return or set the node data for a node. This data is unused
+ // by the implementation.
+ void *GetNodeData(int32 node) const;
+ void SetNodeData(int32 node, void *data);
+
+ // Find a path from "source" to "dest". If such a path exists, place the
+ // node IDs of the nodes on the path in the array path[], and return the
+ // number of nodes on the path. If the path is longer than max_path_len
+ // nodes, only the first max_path_len nodes are placed in path[]. The client
+  // should compare the return value with max_path_len to see when this
+ // occurs. If no path exists, return 0. Any valid path stored in path[]
+ // will start with "source" and end with "dest". There is no guarantee that
+ // the path is the shortest, but no node will appear twice in the path,
+ // except the source and destination node if they are identical; therefore,
+ // the return value is at most one greater than the number of nodes in the
+ // graph.
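+  //
+  // For illustration only (a sketch; "src" and "dst" are placeholder node ids
+  // assumed to have been returned by NewNode()):
+  //
+  //   int32 path[8];
+  //   int len = graph.FindPath(src, dst, 8, path);
+  //   if (len == 0) { ... }       // no path from src to dst
+  //   else if (len > 8) { ... }   // a path exists, but path[] holds only its
+  //                               // first 8 nodes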
+ int FindPath(int32 source, int32 dest, int max_path_len, int32 path[]) const;
+
+ // Check internal invariants. Crashes on failure, returns true on success.
+ // Expensive: should only be called from graphcycles_test.cc.
+ bool CheckInvariants() const;
+
+ std::unordered_set<int32> Successors(int32 node);
+
+ // ----------------------------------------------------
+ struct Rep;
+
+ private:
+ Rep *rep_; // opaque representation
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphCycles);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_COMPILER_JIT_GRAPHCYCLES_GRAPHCYCLES_H_
diff --git a/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
new file mode 100644
index 0000000000..f27a616ac9
--- /dev/null
+++ b/tensorflow/compiler/jit/graphcycles/graphcycles_test.cc
@@ -0,0 +1,515 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A test for the GraphCycles interface.
+
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+
+#include <random>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::string;
+
+// We emulate a GraphCycles object with a node vector and an edge vector.
+// We then compare the two implementations.
+
+typedef std::vector<int> Nodes;
+struct Edge {
+ int from;
+ int to;
+};
+typedef std::vector<Edge> Edges;
+
+// Return whether "to" is reachable from "from".
+static bool IsReachable(Edges *edges, int from, int to,
+ std::unordered_set<int> *seen) {
+ seen->insert(from); // we are investigating "from"; don't do it again
+ if (from == to) return true;
+ for (int i = 0; i != edges->size(); i++) {
+ Edge *edge = &(*edges)[i];
+ if (edge->from == from) {
+ if (edge->to == to) { // success via edge directly
+ return true;
+ } else if (seen->find(edge->to) == seen->end() && // success via edge
+ IsReachable(edges, edge->to, to, seen)) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+static void PrintNodes(Nodes *nodes) {
+ LOG(INFO) << "NODES (" << nodes->size() << ")";
+ for (int i = 0; i != nodes->size(); i++) {
+ LOG(INFO) << (*nodes)[i];
+ }
+}
+
+static void PrintEdges(Edges *edges) {
+ LOG(INFO) << "EDGES (" << edges->size() << ")";
+ for (int i = 0; i != edges->size(); i++) {
+ int a = (*edges)[i].from;
+ int b = (*edges)[i].to;
+ LOG(INFO) << a << " " << b;
+ }
+ LOG(INFO) << "---";
+}
+
+static void PrintGCEdges(Nodes *nodes, tensorflow::GraphCycles *gc) {
+ LOG(INFO) << "GC EDGES";
+ for (int i = 0; i != nodes->size(); i++) {
+ for (int j = 0; j != nodes->size(); j++) {
+ int a = (*nodes)[i];
+ int b = (*nodes)[j];
+ if (gc->HasEdge(a, b)) {
+ LOG(INFO) << a << " " << b;
+ }
+ }
+ }
+ LOG(INFO) << "---";
+}
+
+static void PrintTransitiveClosure(Nodes *nodes, Edges *edges,
+ tensorflow::GraphCycles *gc) {
+ LOG(INFO) << "Transitive closure";
+ for (int i = 0; i != nodes->size(); i++) {
+ for (int j = 0; j != nodes->size(); j++) {
+ int a = (*nodes)[i];
+ int b = (*nodes)[j];
+ std::unordered_set<int> seen;
+ if (IsReachable(edges, a, b, &seen)) {
+ LOG(INFO) << a << " " << b;
+ }
+ }
+ }
+ LOG(INFO) << "---";
+}
+
+static void PrintGCTransitiveClosure(Nodes *nodes,
+ tensorflow::GraphCycles *gc) {
+ LOG(INFO) << "GC Transitive closure";
+ for (int i = 0; i != nodes->size(); i++) {
+ for (int j = 0; j != nodes->size(); j++) {
+ int a = (*nodes)[i];
+ int b = (*nodes)[j];
+ if (gc->IsReachable(a, b)) {
+ LOG(INFO) << a << " " << b;
+ }
+ }
+ }
+ LOG(INFO) << "---";
+}
+
+static void CheckTransitiveClosure(Nodes *nodes, Edges *edges,
+ tensorflow::GraphCycles *gc) {
+ std::unordered_set<int> seen;
+ for (int i = 0; i != nodes->size(); i++) {
+ for (int j = 0; j != nodes->size(); j++) {
+ seen.clear();
+ int a = (*nodes)[i];
+ int b = (*nodes)[j];
+ bool gc_reachable = gc->IsReachable(a, b);
+ CHECK_EQ(gc_reachable, gc->IsReachableNonConst(a, b));
+ bool reachable = IsReachable(edges, a, b, &seen);
+ if (gc_reachable != reachable) {
+ PrintEdges(edges);
+ PrintGCEdges(nodes, gc);
+ PrintTransitiveClosure(nodes, edges, gc);
+ PrintGCTransitiveClosure(nodes, gc);
+ LOG(FATAL) << "gc_reachable " << gc_reachable << " reachable "
+ << reachable << " a " << a << " b " << b;
+ }
+ }
+ }
+}
+
+static void CheckEdges(Nodes *nodes, Edges *edges,
+ tensorflow::GraphCycles *gc) {
+ int count = 0;
+ for (int i = 0; i != edges->size(); i++) {
+ int a = (*edges)[i].from;
+ int b = (*edges)[i].to;
+ if (!gc->HasEdge(a, b)) {
+ PrintEdges(edges);
+ PrintGCEdges(nodes, gc);
+ LOG(FATAL) << "!gc->HasEdge(" << a << ", " << b << ")";
+ }
+ }
+ for (int i = 0; i != nodes->size(); i++) {
+ for (int j = 0; j != nodes->size(); j++) {
+ int a = (*nodes)[i];
+ int b = (*nodes)[j];
+ if (gc->HasEdge(a, b)) {
+ count++;
+ }
+ }
+ }
+ if (count != edges->size()) {
+ PrintEdges(edges);
+ PrintGCEdges(nodes, gc);
+ LOG(FATAL) << "edges->size() " << edges->size() << " count " << count;
+ }
+}
+
+// Returns the index of a randomly chosen node in *nodes.
+// Requires *nodes be non-empty.
+static int RandomNode(std::mt19937 *rnd, Nodes *nodes) {
+ std::uniform_int_distribution<int> distribution(0, nodes->size() - 1);
+ return distribution(*rnd);
+}
+
+// Returns the index of a randomly chosen edge in *edges.
+// Requires *edges be non-empty.
+static int RandomEdge(std::mt19937 *rnd, Edges *edges) {
+ std::uniform_int_distribution<int> distribution(0, edges->size() - 1);
+ return distribution(*rnd);
+}
+
+// Returns the index of edge (from, to) in *edges or -1 if it is not in *edges.
+static int EdgeIndex(Edges *edges, int from, int to) {
+ int i = 0;
+ while (i != edges->size() &&
+ ((*edges)[i].from != from || (*edges)[i].to != to)) {
+ i++;
+ }
+ return i == edges->size() ? -1 : i;
+}
+
+TEST(GraphCycles, RandomizedTest) {
+ Nodes nodes;
+ Edges edges; // from, to
+ tensorflow::GraphCycles graph_cycles;
+ static const int kMaxNodes = 7; // use <= 7 nodes to keep test short
+ static const int kDataOffset = 17; // an offset to the node-specific data
+ int n = 100000;
+ int op = 0;
+ std::mt19937 rnd(tensorflow::testing::RandomSeed() + 1);
+
+ for (int iter = 0; iter != n; iter++) {
+ if ((iter % 10000) == 0) VLOG(0) << "Iter " << iter << " of " << n;
+
+ if (VLOG_IS_ON(3)) {
+ LOG(INFO) << "===============";
+ LOG(INFO) << "last op " << op;
+ PrintNodes(&nodes);
+ PrintEdges(&edges);
+ PrintGCEdges(&nodes, &graph_cycles);
+ }
+ for (int i = 0; i != nodes.size(); i++) {
+ ASSERT_EQ(reinterpret_cast<intptr_t>(graph_cycles.GetNodeData(i)),
+ i + kDataOffset)
+ << " node " << i;
+ }
+ CheckEdges(&nodes, &edges, &graph_cycles);
+ CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
+ std::uniform_int_distribution<int> distribution(0, 5);
+ op = distribution(rnd);
+ switch (op) {
+ case 0: // Add a node
+ if (nodes.size() < kMaxNodes) {
+ int new_node = graph_cycles.NewNode();
+ ASSERT_NE(-1, new_node);
+ VLOG(1) << "adding node " << new_node;
+ ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
+ graph_cycles.SetNodeData(
+ new_node, reinterpret_cast<void *>(
+ static_cast<intptr_t>(new_node + kDataOffset)));
+ ASSERT_GE(new_node, 0);
+ for (int i = 0; i != nodes.size(); i++) {
+ ASSERT_NE(nodes[i], new_node);
+ }
+ nodes.push_back(new_node);
+ }
+ break;
+
+ case 1: // Remove a node
+ if (nodes.size() > 0) {
+ int node_index = RandomNode(&rnd, &nodes);
+ int node = nodes[node_index];
+ nodes[node_index] = nodes.back();
+ nodes.pop_back();
+ VLOG(1) << "removing node " << node;
+ graph_cycles.RemoveNode(node);
+ int i = 0;
+ while (i != edges.size()) {
+ if (edges[i].from == node || edges[i].to == node) {
+ edges[i] = edges.back();
+ edges.pop_back();
+ } else {
+ i++;
+ }
+ }
+ }
+ break;
+
+ case 2: // Add an edge
+ if (nodes.size() > 0) {
+ int from = RandomNode(&rnd, &nodes);
+ int to = RandomNode(&rnd, &nodes);
+ if (EdgeIndex(&edges, nodes[from], nodes[to]) == -1) {
+ if (graph_cycles.InsertEdge(nodes[from], nodes[to])) {
+ Edge new_edge;
+ new_edge.from = nodes[from];
+ new_edge.to = nodes[to];
+ edges.push_back(new_edge);
+ } else {
+ std::unordered_set<int> seen;
+ ASSERT_TRUE(IsReachable(&edges, nodes[to], nodes[from], &seen))
+ << "Edge " << nodes[to] << "->" << nodes[from];
+ }
+ }
+ }
+ break;
+
+ case 3: // Remove an edge
+ if (edges.size() > 0) {
+ int i = RandomEdge(&rnd, &edges);
+ int from = edges[i].from;
+ int to = edges[i].to;
+ ASSERT_EQ(i, EdgeIndex(&edges, from, to));
+ edges[i] = edges.back();
+ edges.pop_back();
+ ASSERT_EQ(-1, EdgeIndex(&edges, from, to));
+ VLOG(1) << "removing edge " << from << " " << to;
+ graph_cycles.RemoveEdge(from, to);
+ }
+ break;
+
+ case 4: // Check a path
+ if (nodes.size() > 0) {
+ int from = RandomNode(&rnd, &nodes);
+ int to = RandomNode(&rnd, &nodes);
+ int32 path[2 * kMaxNodes];
+ int path_len = graph_cycles.FindPath(nodes[from], nodes[to],
+ 2 * kMaxNodes, path);
+ std::unordered_set<int> seen;
+ bool reachable = IsReachable(&edges, nodes[from], nodes[to], &seen);
+ bool gc_reachable = graph_cycles.IsReachable(nodes[from], nodes[to]);
+ ASSERT_EQ(gc_reachable,
+ graph_cycles.IsReachableNonConst(nodes[from], nodes[to]));
+ ASSERT_EQ(path_len != 0, reachable);
+ ASSERT_EQ(path_len != 0, gc_reachable);
+ // In the following line, we add one because a node can appear
+ // twice, if the path is from that node to itself, perhaps via
+ // every other node.
+ ASSERT_LE(path_len, kMaxNodes + 1);
+ if (path_len != 0) {
+ ASSERT_EQ(nodes[from], path[0]);
+ ASSERT_EQ(nodes[to], path[path_len - 1]);
+ for (int i = 1; i < path_len; i++) {
+ ASSERT_NE(-1, EdgeIndex(&edges, path[i - 1], path[i]));
+ ASSERT_TRUE(graph_cycles.HasEdge(path[i - 1], path[i]));
+ }
+ }
+ }
+ break;
+
+ case 5: // Check invariants
+ CHECK(graph_cycles.CheckInvariants());
+ break;
+
+ default:
+ LOG(FATAL);
+ }
+
+ // Very rarely, test graph expansion by adding then removing many nodes.
+ std::bernoulli_distribution rarely(1.0 / 1024.0);
+ if (rarely(rnd)) {
+ VLOG(3) << "Graph expansion";
+ CheckEdges(&nodes, &edges, &graph_cycles);
+ CheckTransitiveClosure(&nodes, &edges, &graph_cycles);
+ for (int i = 0; i != 256; i++) {
+ int new_node = graph_cycles.NewNode();
+ ASSERT_NE(-1, new_node);
+ VLOG(1) << "adding node " << new_node;
+ ASSERT_GE(new_node, 0);
+ ASSERT_EQ(0, graph_cycles.GetNodeData(new_node));
+ graph_cycles.SetNodeData(
+ new_node, reinterpret_cast<void *>(
+ static_cast<intptr_t>(new_node + kDataOffset)));
+ for (int j = 0; j != nodes.size(); j++) {
+ ASSERT_NE(nodes[j], new_node);
+ }
+ nodes.push_back(new_node);
+ }
+ for (int i = 0; i != 256; i++) {
+ ASSERT_GT(nodes.size(), 0);
+ int node_index = RandomNode(&rnd, &nodes);
+ int node = nodes[node_index];
+ nodes[node_index] = nodes.back();
+ nodes.pop_back();
+ VLOG(1) << "removing node " << node;
+ graph_cycles.RemoveNode(node);
+ int j = 0;
+ while (j != edges.size()) {
+ if (edges[j].from == node || edges[j].to == node) {
+ edges[j] = edges.back();
+ edges.pop_back();
+ } else {
+ j++;
+ }
+ }
+ }
+ CHECK(graph_cycles.CheckInvariants());
+ }
+ }
+}
+
+class GraphCyclesTest : public ::testing::Test {
+ public:
+ tensorflow::GraphCycles g_;
+
+  // Test relies on the ith NewNode() call returning the node numbered i.
+ GraphCyclesTest() {
+ for (int i = 0; i < 100; i++) {
+ CHECK_EQ(i, g_.NewNode());
+ }
+ CHECK(g_.CheckInvariants());
+ }
+
+ bool AddEdge(int x, int y) { return g_.InsertEdge(x, y); }
+
+ void AddMultiples() {
+ // For every node x > 0: add edge to 2*x, 3*x
+ for (int x = 1; x < 25; x++) {
+ EXPECT_TRUE(AddEdge(x, 2 * x)) << x;
+ EXPECT_TRUE(AddEdge(x, 3 * x)) << x;
+ }
+ CHECK(g_.CheckInvariants());
+ }
+
+ string Path(int x, int y) {
+ static const int kPathSize = 5;
+ int32 path[kPathSize];
+ int np = g_.FindPath(x, y, kPathSize, path);
+ string result;
+ for (int i = 0; i < np; i++) {
+ if (i >= kPathSize) {
+ result += " ...";
+ break;
+ }
+ if (!result.empty()) result.push_back(' ');
+ char buf[20];
+ snprintf(buf, sizeof(buf), "%d", path[i]);
+ result += buf;
+ }
+ return result;
+ }
+};
+
+TEST_F(GraphCyclesTest, NoCycle) {
+ AddMultiples();
+ CHECK(g_.CheckInvariants());
+}
+
+TEST_F(GraphCyclesTest, SimpleCycle) {
+ AddMultiples();
+ EXPECT_FALSE(AddEdge(8, 4));
+ EXPECT_EQ("4 8", Path(4, 8));
+ CHECK(g_.CheckInvariants());
+}
+
+TEST_F(GraphCyclesTest, IndirectCycle) {
+ AddMultiples();
+ EXPECT_TRUE(AddEdge(16, 9));
+ CHECK(g_.CheckInvariants());
+ EXPECT_FALSE(AddEdge(9, 2));
+ EXPECT_EQ("2 4 8 16 9", Path(2, 9));
+ CHECK(g_.CheckInvariants());
+}
+
+TEST_F(GraphCyclesTest, LongPath) {
+ ASSERT_TRUE(AddEdge(2, 4));
+ ASSERT_TRUE(AddEdge(4, 6));
+ ASSERT_TRUE(AddEdge(6, 8));
+ ASSERT_TRUE(AddEdge(8, 10));
+ ASSERT_TRUE(AddEdge(10, 12));
+ ASSERT_FALSE(AddEdge(12, 2));
+ EXPECT_EQ("2 4 6 8 10 ...", Path(2, 12));
+ CHECK(g_.CheckInvariants());
+}
+
+TEST_F(GraphCyclesTest, RemoveNode) {
+ ASSERT_TRUE(AddEdge(1, 2));
+ ASSERT_TRUE(AddEdge(2, 3));
+ ASSERT_TRUE(AddEdge(3, 4));
+ ASSERT_TRUE(AddEdge(4, 5));
+ g_.RemoveNode(3);
+ ASSERT_TRUE(AddEdge(5, 1));
+}
+
+TEST_F(GraphCyclesTest, ManyEdges) {
+ const int N = 50;
+ for (int i = 0; i < N; i++) {
+ for (int j = 1; j < N; j++) {
+ ASSERT_TRUE(AddEdge(i, i + j));
+ }
+ }
+ CHECK(g_.CheckInvariants());
+ ASSERT_TRUE(AddEdge(2 * N - 1, 0));
+ CHECK(g_.CheckInvariants());
+ ASSERT_FALSE(AddEdge(10, 9));
+ CHECK(g_.CheckInvariants());
+}
+
+TEST_F(GraphCyclesTest, ContractEdge) {
+ ASSERT_TRUE(AddEdge(1, 2));
+ ASSERT_TRUE(AddEdge(1, 3));
+ ASSERT_TRUE(AddEdge(2, 3));
+ ASSERT_TRUE(AddEdge(2, 4));
+ ASSERT_TRUE(AddEdge(3, 4));
+
+ EXPECT_FALSE(g_.ContractEdge(1, 3));
+ CHECK(g_.CheckInvariants());
+ EXPECT_TRUE(g_.HasEdge(1, 3));
+
+ EXPECT_TRUE(g_.ContractEdge(1, 2));
+ CHECK(g_.CheckInvariants());
+ EXPECT_TRUE(g_.HasEdge(1, 3));
+ EXPECT_TRUE(g_.HasEdge(1, 4));
+ EXPECT_TRUE(g_.HasEdge(3, 4));
+
+ EXPECT_TRUE(g_.ContractEdge(1, 3));
+ CHECK(g_.CheckInvariants());
+ EXPECT_TRUE(g_.HasEdge(1, 4));
+}
+
+static void BM_StressTest(int iters, int num_nodes) {
+ while (iters > 0) {
+ tensorflow::GraphCycles g;
+ int32 *nodes = new int32[num_nodes];
+ for (int i = 0; i < num_nodes; i++) {
+ nodes[i] = g.NewNode();
+ }
+ for (int i = 0; i < num_nodes && iters > 0; i++, iters--) {
+ int end = std::min(num_nodes, i + 5);
+ for (int j = i + 1; j < end; j++) {
+ if (nodes[i] >= 0 && nodes[j] >= 0) {
+ CHECK(g.InsertEdge(nodes[i], nodes[j]));
+ }
+ }
+ }
+ delete[] nodes;
+ }
+}
+BENCHMARK(BM_StressTest)->Range(2048, 1048576);
diff --git a/tensorflow/compiler/jit/jit_compilation_pass_registration.cc b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
new file mode 100644
index 0000000000..4d49a14b24
--- /dev/null
+++ b/tensorflow/compiler/jit/jit_compilation_pass_registration.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/build_xla_launch_ops_pass.h"
+#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 10,
+ MarkForCompilationPass);
+
+// The EncapsulateSubgraphs pass must run after the MarkForCompilationPass. We
+// also need to run it after the graph has been rewritten to have _Send nodes
+// added for fetches. Before the _Send nodes are added, fetch nodes are
+// identified by name, and encapsulation might remove that node from the graph.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 20,
+ EncapsulateSubgraphsPass);
+
+// Must run after EncapsulateSubgraphsPass.
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 30,
+ BuildXlaLaunchOpsPass);
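+
+// Note on ordering: the phase arguments above (10, 20, 30) order the passes
+// within POST_REWRITE_FOR_EXEC, so clusters are marked first, then
+// encapsulated into functions, and finally rewritten into XLA launch ops.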
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/BUILD b/tensorflow/compiler/jit/legacy_flags/BUILD
new file mode 100644
index 0000000000..4491dd6ac8
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/BUILD
@@ -0,0 +1,67 @@
+# Legacy command line flags for the XLA bridge libraries.
+
+# Please do not add more flags to this package.
+
+# The XLA bridge libraries were written in an environment that allowed
+# command-line flags to be scattered freely throughout the libraries. This
+# model, while initially convenient, leads to a proliferation of unused command
+# line flags in tests and binaries, and serious problems in servers, where one
+# might wish parameters to be different in independent RPC calls to the same
+# routine.
+#
+# Please don't add more flags. If you're a library author, pass options and
+# parameters explicitly through the library's interface.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+cc_library(
+ name = "encapsulate_subgraphs_pass_flags",
+ srcs = ["encapsulate_subgraphs_pass_flags.cc"],
+ hdrs = ["encapsulate_subgraphs_pass_flags.h"],
+ deps =
+ [
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "mark_for_compilation_pass_flags",
+ srcs = ["mark_for_compilation_pass_flags.cc"],
+ hdrs = ["mark_for_compilation_pass_flags.h"],
+ deps =
+ [
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "parallel_check_op_flags",
+ srcs = ["parallel_check_op_flags.cc"],
+ hdrs = ["parallel_check_op_flags.h"],
+ deps =
+ [
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc
new file mode 100644
index 0000000000..856475f12c
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static EncapsulateSubgraphsPassFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new EncapsulateSubgraphsPassFlags;
+ flags->tf_xla_parallel_checking = false;
+ flag_list = new std::vector<Flag>({
+ Flag("tf_xla_parallel_checking", &flags->tf_xla_parallel_checking,
+ "Debug tool. Runs both JIT-compiled and interpreted graphs in "
+ "parallel and verifies they produce the same outputs."),
+ });
+ xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// encapsulate_subgraphs_pass module.
+void AppendEncapsulateSubgraphsPassFlags(std::vector<Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h
new file mode 100644
index 0000000000..d371bd269d
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
+
+// Legacy flags for the XLA bridge's encapsulate_subgraphs_pass module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with the XLA bridge's
+// encapsulate_subgraphs_pass module.
+void AppendEncapsulateSubgraphsPassFlags(
+ std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with the XLA bridge's
+// encapsulate_subgraphs_pass module.
+typedef struct {
+ bool tf_xla_parallel_checking; // Debug tool. Runs both JIT-compiled and
+ // interpreted graphs in parallel and verifies
+ // they produce the same outputs.
+} EncapsulateSubgraphsPassFlags;
+
+// Return a pointer to the EncapsulateSubgraphsPassFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+EncapsulateSubgraphsPassFlags* GetEncapsulateSubgraphsPassFlags();
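+
+// Typical use (a hedged sketch; the surrounding call site is hypothetical):
+//
+//   EncapsulateSubgraphsPassFlags* flags = GetEncapsulateSubgraphsPassFlags();
+//   if (flags->tf_xla_parallel_checking) {
+//     // run the interpreted graph alongside the JIT-compiled one
+//   }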
+
+} // namespace legacy_flags
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_ENCAPSULATE_SUBGRAPHS_PASS_FLAGS_H_
diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
new file mode 100644
index 0000000000..09aee39d8c
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static MarkForCompilationPassFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new MarkForCompilationPassFlags;
+ flags->tf_xla_auto_jit = 0;
+ flags->tf_xla_min_cluster_size = 2;
+ flags->tf_xla_max_cluster_size = std::numeric_limits<int32>::max();
+ flags->tf_xla_clustering_debug = false;
+ flag_list = new std::vector<Flag>({
+ Flag("tf_xla_auto_jit", &flags->tf_xla_auto_jit,
+ "Control compilation of operators into XLA computations on CPU and "
+ "GPU devices. 0 = use ConfigProto setting; -1 = off; 1 = on for "
+ "things very likely to be improved; 2 = on for everything. "
+ "Experimental."),
+ Flag("tf_xla_min_cluster_size", &flags->tf_xla_min_cluster_size,
+ "Minimum number of operators in an XLA compilation. Ignored for "
+ "operators placed on an XLA device or operators explicitly marked "
+ "for compilation."),
+ Flag("tf_xla_max_cluster_size", &flags->tf_xla_max_cluster_size,
+ "Maximum number of operators in an XLA compilation."),
+ Flag("tf_xla_clustering_debug", &flags->tf_xla_clustering_debug,
+ "Dump graphs during XLA compilation."),
+ });
+ xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// mark_for_compilation_pass module.
+void AppendMarkForCompilationPassFlags(std::vector<Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the MarkForCompilationPassFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
new file mode 100644
index 0000000000..24f8050742
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
+
+// Legacy flags for the XLA bridge's mark_for_compilation_pass module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with the XLA bridge's
+// mark_for_compilation_pass module.
+void AppendMarkForCompilationPassFlags(
+ std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with the XLA bridge's
+// mark_for_compilation_pass module.
+typedef struct {
+ int32 tf_xla_auto_jit; // Control compilation of operators into XLA
+ // computations on CPU and GPU devices. 0 = use
+ // ConfigProto setting; -1 = off; 1 = on for things
+ // very likely to be improved; 2 = on for everything.
+ // Experimental.
+ int32 tf_xla_min_cluster_size; // Minimum number of operators in an XLA
+ // compilation. Ignored for operators placed
+ // on an XLA device or operators explicitly
+ // marked for compilation.
+ int32 tf_xla_max_cluster_size; // Maximum number of operators in an XLA
+ // compilation.
+ bool tf_xla_clustering_debug; // Dump graphs during XLA compilation.
+} MarkForCompilationPassFlags;
+
+// Return a pointer to the MarkForCompilationPassFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+MarkForCompilationPassFlags* GetMarkForCompilationPassFlags();
+
+} // namespace legacy_flags
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_MARK_FOR_COMPILATION_PASS_FLAGS_H_
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
new file mode 100644
index 0000000000..a61694b494
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's parallel_check_op module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static ParallelCheckOpFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new ParallelCheckOpFlags;
+ flags->parallel_check_failfast = true;
+ flags->parallel_check_atol = "1e-5";
+ flags->parallel_check_rtol = "1e-5";
+ flag_list = new std::vector<Flag>({
+ Flag("parallel_check_failfast", &flags->parallel_check_failfast,
+ "Fail immediately on first parallel-check comparison error."),
+ Flag("parallel_check_atol", &flags->parallel_check_atol,
+ "Absolute error tolerance for parallel-check comparison."),
+ Flag("parallel_check_rtol", &flags->parallel_check_rtol,
+ "Relative error tolerance for parallel-check comparison."),
+ });
+ xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// parallel_check_op module.
+void AppendParallelCheckOpFlags(std::vector<Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the ParallelCheckOpFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ParallelCheckOpFlags* GetParallelCheckOpFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
new file mode 100644
index 0000000000..156a2a2a71
--- /dev/null
+++ b/tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
+#define TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
+
+// Legacy flags for the XLA bridge's parallel_check_op module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with the XLA bridge's
+// parallel_check_op module.
+void AppendParallelCheckOpFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with the XLA bridge's
+// parallel_check_op module.
+typedef struct {
+ bool parallel_check_failfast; // Fail immediately on first parallel-check
+ // comparison error.
+ string parallel_check_atol; // Absolute error tolerance for parallel-check
+ // comparison.
+ string parallel_check_rtol; // Relative error tolerance for parallel-check
+ // comparison.
+} ParallelCheckOpFlags;
+
+// Return a pointer to the ParallelCheckOpFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ParallelCheckOpFlags* GetParallelCheckOpFlags();
+
+} // namespace legacy_flags
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_LEGACY_FLAGS_PARALLEL_CHECK_OP_FLAGS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
new file mode 100644
index 0000000000..486725f1da
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -0,0 +1,534 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+
+#include <atomic>
+#include <deque>
+#include <limits>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
+#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/memory_types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/control_flow.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+const char* const kXlaClusterAttr = "_XlaCluster";
+
+namespace {
+
+bool HasXLAKernel(const NodeDef& node_def, DeviceType jit_device_type) {
+ // There is a SymbolicGradient kernel on the XLA_JIT device, but the gradient
+ // is really a kind of function call and will be handled by
+ // IsCompilableCall().
+ if (node_def.op() == "SymbolicGradient") return false;
+ return FindKernelDef(jit_device_type, node_def, nullptr, nullptr).ok();
+}
+
+// Make sure we don't recurse infinitely on recursive functions.
+const int kMaxRecursionDepth = 5;
+
+bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
+ int depth, FunctionLibraryRuntime* lib_runtime);
+
+// Tests whether 'while_def' is a completely compilable loop.
+// Every operator in the condition and body functions must be compilable for a
+// while loop to be compilable.
+bool IsCompilableWhile(const NodeDef& while_def, DeviceType jit_device_type,
+ int depth, FunctionLibraryRuntime* lib_runtime) {
+ VLOG(2) << "Loop marking: " << while_def.op();
+
+ const NameAttrList* name_attr;
+ NodeDef call;
+ Status status;
+ status = GetNodeAttr(while_def, "cond", &name_attr);
+ if (!status.ok()) {
+ VLOG(2) << "Missing 'cond' attribute on While node.";
+ return false;
+ }
+ const string cond_func = name_attr->name();
+ call.set_name("while_cond");
+ call.set_op(cond_func);
+ *call.mutable_attr() = name_attr->attr();
+ if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ VLOG(2) << "Can't compile loop condition: " << cond_func;
+ return false;
+ }
+ status = GetNodeAttr(while_def, "body", &name_attr);
+ if (!status.ok()) {
+ VLOG(2) << "Missing 'body' attribute on While node.";
+ return false;
+ }
+ const string body_func = name_attr->name();
+ call.set_name("while_body");
+ call.set_op(body_func);
+ *call.mutable_attr() = name_attr->attr();
+ if (!IsCompilableCall(call, jit_device_type, depth + 1, lib_runtime)) {
+ VLOG(2) << "Can't compile loop body: " << body_func;
+ return false;
+ }
+ VLOG(2) << "Loop is compilable.";
+ return true;
+}
+
+// Tests whether 'call_def' is a call to a completely compilable function.
+// Every operator in the function must be compilable for a function to be
+// compilable.
+bool IsCompilableCall(const NodeDef& call_def, DeviceType jit_device_type,
+ int depth, FunctionLibraryRuntime* lib_runtime) {
+ VLOG(2) << "Function marking: " << call_def.op();
+
+ if (depth > kMaxRecursionDepth) {
+ VLOG(2) << "Function depth limit exceeded";
+ return false;
+ }
+
+ FunctionLibraryRuntime::Handle handle;
+ Status status =
+ lib_runtime->Instantiate(call_def.op(), call_def.attr(), &handle);
+ if (!status.ok()) {
+ VLOG(2) << "Could not instantiate " << call_def.op() << ": " << status;
+ return false;
+ }
+ const FunctionBody* fbody = lib_runtime->GetFunctionBody(handle);
+ CHECK(fbody);
+
+ for (Node* node : fbody->graph->nodes()) {
+ if (node->IsSource() || node->IsSink()) continue;
+ if (node->def().op() == "_Arg" || node->def().op() == "_Retval") continue;
+ if (node->def().op() == "While") {
+ // Handle functional While loop (not in open source build).
+ return IsCompilableWhile(node->def(), jit_device_type, depth + 1,
+ lib_runtime);
+ }
+ if (!HasXLAKernel(node->def(), jit_device_type) &&
+ !IsCompilableCall(node->def(), jit_device_type, depth + 1,
+ lib_runtime)) {
+ VLOG(2) << "Function marking failed: unsupported op " << node->name()
+ << ": " << node->def().ShortDebugString();
+ return false;
+ }
+ }
+ VLOG(2) << "Function is compilable: " << call_def.op();
+ return true;
+}
+
+// Returns the DeviceType corresponding to 'device'.
+Status DeviceTypeOfDevice(const string& device, DeviceType* device_type) {
+ DeviceNameUtils::ParsedName parsed;
+ if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
+ return errors::Internal("Malformed assigned device '", device, "'");
+ }
+ *device_type = DeviceType(parsed.type);
+ return Status::OK();
+}
+
+Status FindCompilationCandidates(
+ const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
+ const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
+ std::unordered_set<Node*>* candidates) {
+ OptimizerOptions opts;
+ std::unique_ptr<FunctionLibraryRuntime> lib_runtime(NewFunctionLibraryRuntime(
+ nullptr, env, nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts));
+
+ for (Node* node : graph.nodes()) {
+ if (node->IsSource() || node->IsSink()) continue;
+
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceTypeOfDevice(node->assigned_device_name(), &device_type));
+
+ if (is_compilable_fn && !is_compilable_fn(node, device_type)) continue;
+
+ const string* jit_device_name;
+ CHECK(XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device_name,
+ /*requires_jit=*/nullptr));
+ DeviceType jit_device_type(*jit_device_name);
+ if (!HasXLAKernel(node->def(), jit_device_type) &&
+ !IsCompilableCall(node->def(), jit_device_type, 0, lib_runtime.get())) {
+ VLOG(2) << "Compilation rejected node: unsupported op " << node->name()
+ << ": " << node->def().op();
+ continue;
+ }
+ if (node->def().op() == "While" &&
+ !IsCompilableWhile(node->def(), jit_device_type, 0,
+ lib_runtime.get())) {
+ continue;
+ }
+ candidates->insert(node);
+ }
+ return Status::OK();
+}
+
+// Union-Find data structure used to compute clusters. We use our own
+// implementation because we want one key feature: when merging clusters, we
+// need to know which value becomes the representative of the merged clusters.
+// We use the representatives to name nodes in a cycle detection graph, and we
+// need to control which node is named.
+// TODO(phawkins): consider merging this code with union-find implementations
+// in Tensorflow, e.g., in SimplePlacer.
+class Cluster {
+ public:
+ Cluster();
+
+ int Size() { return FindRoot()->size_; }
+
+ // Merges this cluster with 'other'. This cluster's representative becomes
+ // the representative of the merged cluster; the representative of 'other'
+ // is ignored.
+ void Merge(Cluster* other);
+
+ // Each cluster has an associated integer 'representative', initialized to -1
+ // by default.
+ int GetRepresentative() { return FindRoot()->representative_; }
+ void SetRepresentative(int representative) {
+ FindRoot()->representative_ = representative;
+ }
+
+ private:
+ // Finds the root element of the cluster. Performs path compression.
+ Cluster* FindRoot();
+
+ int representative_;
+ int rank_;
+ int size_; // Size of the cluster.
+ Cluster* parent_;
+};
+
+Cluster::Cluster()
+ : representative_(-1), rank_(0), size_(1), parent_(nullptr) {}
+
+void Cluster::Merge(Cluster* other) {
+ Cluster* a = FindRoot();
+ Cluster* b = other->FindRoot();
+ if (a == b) return;
+ if (a->rank_ > b->rank_) {
+ b->parent_ = a;
+ a->size_ += b->size_;
+ return;
+ }
+
+ a->parent_ = b;
+ if (a->rank_ == b->rank_) {
+ b->rank_++;
+ }
+ b->representative_ = a->representative_;
+ b->size_ += a->size_;
+}
+
+Cluster* Cluster::FindRoot() {
+ if (!parent_) return this;
+ // Path compression: update intermediate nodes to point to the root of the
+ // equivalence class.
+ parent_ = parent_->FindRoot();
+ return parent_;
+}
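+
+// Illustrative sketch of the representative-preserving merge described above
+// (hypothetical values; not part of the pass itself):
+//
+//   Cluster c1, c2;
+//   c1.SetRepresentative(1);
+//   c2.SetRepresentative(2);
+//   c1.Merge(&c2);
+//   // Both c1.GetRepresentative() and c2.GetRepresentative() now return 1:
+//   // the merged cluster keeps c1's representative, as required for naming
+//   // nodes in the cycle detection graph.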
+
+} // anonymous namespace
+
+bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef) {
+ Device* device = flr->device();
+ const string* jit_device_name;
+ CHECK(XlaOpRegistry::GetJitDevice(device->device_type(), &jit_device_name,
+ /*requires_jit=*/nullptr));
+ DeviceType jit_device_type(*jit_device_name);
+ return IsCompilableCall(ndef, jit_device_type, 0, flr);
+}
+
+Status MarkForCompilationPass::Run(
+ const GraphOptimizationPassOptions& options) {
+  // TODO(phawkins): precompute the "GetJitDevice" properties for each device
+  // ahead of time.
+ OptimizerOptions::GlobalJitLevel global_jit_level =
+ options.session_options->config.graph_options()
+ .optimizer_options()
+ .global_jit_level();
+ if (global_jit_level == OptimizerOptions::DEFAULT) {
+ // To set compilation to be on by default, change the following line.
+ global_jit_level = OptimizerOptions::OFF;
+ }
+ legacy_flags::MarkForCompilationPassFlags* flags =
+ legacy_flags::GetMarkForCompilationPassFlags();
+ if (flags->tf_xla_auto_jit == -1 ||
+ (1 <= flags->tf_xla_auto_jit && flags->tf_xla_auto_jit <= 2)) {
+ // If the flag tf_xla_auto_jit is a valid, non-zero setting, it overrides
+ // the setting in ConfigProto.
+ global_jit_level =
+ static_cast<OptimizerOptions::GlobalJitLevel>(flags->tf_xla_auto_jit);
+ }
+ const FunctionLibraryDefinition* fld = options.flib_def;
+ auto is_compilable = [global_jit_level, fld](const Node* node,
+ const DeviceType& device_type) {
+ const string* jit_device;
+ bool requires_jit;
+ if (!XlaOpRegistry::GetJitDevice(device_type.type(), &jit_device,
+ &requires_jit)) {
+ return false;
+ }
+ // If this device requires a JIT, we must say yes.
+ if (requires_jit) return true;
+
+ // If there is a _XlaCompile annotation, use its value.
+ bool compile = false;
+ Status status = GetNodeAttr(node->def(), kXlaCompileAttr, &compile);
+ if (status.ok()) return compile;
+
+ status = fld->GetAttr(node->def(), kXlaCompileAttr, &compile);
+ if (status.ok()) return compile;
+
+ // Otherwise use the value of global_jit_level.
+ return global_jit_level > 0;
+ };
+ return RunImpl(options, is_compilable);
+}
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+static bool IsShapeConsumerOp(const Node& node) {
+ return node.type_string() == "Shape" || node.type_string() == "Rank" ||
+ node.type_string() == "Size";
+}
+
+// Sequence number generator to ensure clusters have unique names.
+static std::atomic<int64> cluster_sequence_num;
+
+Status MarkForCompilationPass::RunImpl(
+ const GraphOptimizationPassOptions& options,
+ const std::function<bool(const Node*, const DeviceType&)>&
+ is_compilable_fn) {
+ VLOG(1) << "MarkForCompilationPass::Run";
+
+ // Make sure that kernels have been registered on the JIT device.
+ XlaOpRegistry::RegisterJitKernels();
+
+ Graph* graph = options.graph->get();
+
+ std::unordered_set<Node*> compilation_candidates;
+ TF_RETURN_IF_ERROR(FindCompilationCandidates(
+ *graph, options.flib_def,
+ (options.session_options != nullptr) ? options.session_options->env
+ : Env::Default(),
+ is_compilable_fn, &compilation_candidates));
+
+ GraphCycles cycles;
+ for (int i = 0; i < graph->num_node_ids(); ++i) {
+ // We rely on the node IDs in the cycle detection graph being consecutive
+ // integers starting from 0.
+ CHECK_EQ(i, cycles.NewNode());
+ }
+
+ // Compute the loop structure of the graph.
+ std::vector<ControlFlowInfo> control_flow_info;
+ TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &control_flow_info));
+
+ // The clustering code must avoid adding cycles to the graph to prevent
+ // deadlock. However, the graph may contain loops, which would trigger the
+ // cycle detection code. To handle loops, we alter the structure of the cycle
+ // detection graph, disconnecting each loop from the enclosing graph.
+ // Specifically, we:
+ // * add a new "frame" node for each loop.
+ // * replace edges to "Enter" nodes, and edges from "Exit" nodes with edges
+ // to/from the corresponding frame node. In essence, we collapse the loop
+ // into a single node for the purpose of cycle detection in the enclosing
+ // graph.
+ // * the body of the loop should now be disconnected from the rest of the
+ // graph; we make it acyclic by breaking loop backedges (edges outgoing from
+  //   "NextIteration" nodes).
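+  //
+  // For example (a schematic sketch, not real node names): for a loop in
+  // frame F, an edge a->Enter_F becomes a->Frame_F, an edge Exit_F->b becomes
+  // Frame_F->b, and any edge leaving a "NextIteration" node is dropped
+  // entirely.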
+
+ // Map from frame name strings to node IDs in the cycle detection graph.
+ std::unordered_map<string, int> frame_nodes;
+
+ // Get the cycle graph node ID for frame 'frame_name', or add one if none
+ // exists.
+ auto GetOrAddFrameNodeId = [&frame_nodes, &cycles](const string& frame_name) {
+ int& frame_id = frame_nodes.emplace(frame_name, -1).first->second;
+ if (frame_id < 0) {
+ // The emplace succeeded; we have not allocated a frame node yet.
+ frame_id = cycles.NewNode();
+ }
+ return frame_id;
+ };
+
+ for (Edge const* edge : graph->edges()) {
+ if (edge->dst()->IsEnter()) {
+ // Lift edges to an "Enter" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->dst()->id()].frame_name;
+ if (!cycles.InsertEdge(edge->src()->id(),
+ GetOrAddFrameNodeId(frame_name))) {
+ return errors::Internal("Cycle detected when adding enter->frame edge");
+ }
+ continue;
+ }
+ if (edge->src()->IsExit()) {
+ // Lift edges from an "Exit" node to the corresponding frame node.
+ const string& frame_name =
+ control_flow_info[edge->src()->id()].frame_name;
+ if (!cycles.InsertEdge(GetOrAddFrameNodeId(frame_name),
+ edge->dst()->id())) {
+ return errors::Internal("Cycle detected when adding frame->exit edge");
+ }
+ // Drop the original edge.
+ continue;
+ }
+ if (edge->src()->IsNextIteration()) {
+ // Break loop back-edges.
+ continue;
+ }
+ if (!cycles.InsertEdge(edge->src()->id(), edge->dst()->id())) {
+ // This should never happen. All cycles in the graph should contain
+ // a control flow operator.
+ return errors::Internal(
+ "Found cycle in graph without control flow operator during XLA "
+ "compilation.");
+ }
+ }
+
+  // Each compilation candidate belongs to a cluster. The cluster's
+  // representative names the node in the 'cycles' graph that represents the
+  // cluster.
+ std::vector<Cluster> clusters(graph->num_node_ids());
+ std::deque<Cluster*> worklist;
+ for (Node* node : compilation_candidates) {
+ clusters[node->id()].SetRepresentative(node->id());
+ worklist.push_back(&clusters[node->id()]);
+ }
+
+ legacy_flags::MarkForCompilationPassFlags* flags =
+ legacy_flags::GetMarkForCompilationPassFlags();
+
+ // Repeatedly contract edges between clusters that are on the same device,
+ // provided the contraction would not create a cycle.
+ while (!worklist.empty()) {
+ int from = worklist.front()->GetRepresentative();
+ worklist.pop_front();
+
+ Node* node_from = graph->FindNodeId(from);
+ if (node_from->IsControlFlow()) {
+ // Control flow nodes aren't compilation candidates and should never
+ // appear.
+ return errors::Internal("Found control flow node in clustering worklist");
+ }
+ for (int to : cycles.Successors(from)) {
+ if (to >= graph->num_node_ids()) {
+ // Node is a "frame" node that is present only in the cycle detection
+ // graph. No clustering is possible.
+ continue;
+ }
+ Node* node_to = graph->FindNodeId(to);
+ if (compilation_candidates.find(node_to) == compilation_candidates.cend())
+ continue;
+ if (node_from->assigned_device_name() != node_to->assigned_device_name())
+ continue;
+
+ // Ops that consume shapes cannot be the root of a cluster. This is an
+ // optimization.
+ if (clusters[from].Size() == 1 && IsShapeConsumerOp(*node_from)) {
+ continue;
+ }
+
+ // Don't exceed the maximum cluster size.
+ if (clusters[from].Size() + clusters[to].Size() >
+ flags->tf_xla_max_cluster_size) {
+ continue;
+ }
+
+ // If contracting the edge would create a cycle, bail out.
+ // However, just because we can't merge the clusters now does not mean
+ // we won't be able to merge them in the future.
+ // e.g., if we have edges 1->2, 2->3 and 1->3, we cannot contract edge
+ // 1->3. But if we first contract 1->2 then we can later contract 1->3.
+ if (!cycles.ContractEdge(from, to)) continue;
+
+      // Merge the clusters. ContractEdge keeps 'from' as the id of the merged
+      // node, so make sure 'from' is the chosen representative.
+ clusters[from].Merge(&clusters[to]);
+
+ worklist.push_back(&clusters[from]);
+ break;
+ }
+ }
+
+ // Count the number of elements in each cluster.
+ std::vector<int> cluster_sizes(graph->num_node_ids());
+ for (const Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].GetRepresentative();
+ cluster_sizes[cluster]++;
+ }
+
+ // Names for each cluster.
+ std::unordered_map<int, string> cluster_names;
+
+ // Mark clusters for compilation that:
+ // * are placed on a device that requires compilation (an XlaDevice),
+ // * are explicitly marked for compilation (_XlaCompile=true), or
+  // * have at least flags->tf_xla_min_cluster_size elements (applicable only
+  //   if compilation is enabled, otherwise there will be no such candidates).
+ const int min_cluster_size = flags->tf_xla_min_cluster_size;
+ for (Node* n : compilation_candidates) {
+ int cluster = clusters[n->id()].GetRepresentative();
+
+ // Compile if the user marked this node _XlaCompile=true
+ bool compile_attr = false;
+ bool marked_for_compilation = false;
+ if (GetNodeAttr(n->def(), kXlaCompileAttr, &compile_attr).ok()) {
+ marked_for_compilation = compile_attr;
+ } else if (options.flib_def
+ ->GetAttr(n->def(), kXlaCompileAttr, &compile_attr)
+ .ok()) {
+ marked_for_compilation = compile_attr;
+ }
+
+ // Compile if this operator is placed on a device that requires
+ // compilation.
+ bool requires_jit = false;
+ DeviceType device_type("");
+ TF_RETURN_IF_ERROR(
+ DeviceTypeOfDevice(n->assigned_device_name(), &device_type));
+ XlaOpRegistry::GetJitDevice(device_type.type(),
+ /*jit_device_name=*/nullptr, &requires_jit);
+
+ // Or compile if this is a cluster of >= min_cluster_size compilable
+ // operators.
+ if (cluster_sizes[cluster] >= min_cluster_size || marked_for_compilation ||
+ requires_jit) {
+ string& name = cluster_names[cluster];
+ if (name.empty()) {
+ name = strings::StrCat("cluster_", cluster_sequence_num++);
+ }
+ n->AddAttr(kXlaClusterAttr, name);
+ }
+ }
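+
+  // Every node chosen above now carries kXlaClusterAttr (e.g. "cluster_0");
+  // the EncapsulateSubgraphsPass later groups nodes by this attribute value.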
+
+ if (flags->tf_xla_clustering_debug) {
+ dump_graph::DumpGraphToFile("mark_for_compilation", **options.graph,
+ options.flib_def);
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h
new file mode 100644
index 0000000000..f91695800f
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An optimization pass that marks nodes that are to be compiled with
+// attribute kXlaClusterAttr. Nodes with the same cluster ID will be compiled
+// together.
+
+#ifndef TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
+#define TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
+
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+
+namespace tensorflow {
+
+// The attribute that marks nodes to be grouped into functions by the
+// encapsulate subgraphs pass.
+extern const char* const kXlaClusterAttr;
+
+// Pass that marks a subset of operators in the graph with attribute
+// _XlaCluster so they are compiled by the EncapsulateSubgraphsPass.
+class MarkForCompilationPass : public GraphOptimizationPass {
+ public:
+ MarkForCompilationPass() = default;
+
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ // Run() just calls RunImpl() if --tf_xla_auto_jit is enabled. To run the pass
+ // unconditionally, call RunImpl() directly.
+ // is_compilable_fn, if set, is a predicate that must be true for a node to
+ // be compiled.
+ Status RunImpl(const GraphOptimizationPassOptions& options,
+ const std::function<bool(const Node*, const DeviceType&)>&
+ is_compilable_fn = {});
+};
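+
+// Typical direct use (a sketch mirroring the unit tests; the variable names
+// below belong to the caller):
+//
+//   GraphOptimizationPassOptions opt_options;
+//   opt_options.graph = &graph;        // graph is a std::unique_ptr<Graph>
+//   opt_options.flib_def = &flib_def;  // FunctionLibraryDefinition
+//   MarkForCompilationPass pass;
+//   TF_RETURN_IF_ERROR(pass.RunImpl(opt_options));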
+
+// Returns true iff 'ndef' is a call to a function that is compilable. A
+// function is compilable iff every operator in the function body is
+// compilable.
+bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_MARK_FOR_COMPILATION_PASS_H_
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
new file mode 100644
index 0000000000..560695e87d
--- /dev/null
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -0,0 +1,357 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("UncompilableNullary").Output("o: float");
+REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");
+
+void MarkForCompilation(std::unique_ptr<Graph>* graph,
+ FunctionLibraryDefinition* flib_def) {
+ // Assign all nodes to the CPU device.
+ static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0";
+ for (Node* n : (*graph)->nodes()) {
+ n->set_assigned_device_name(kCpuDevice);
+ }
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = flib_def;
+ MarkForCompilationPass pass;
+ CHECK(pass.RunImpl(opt_options).ok());
+}
+
+void MarkForCompilation(std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+ MarkForCompilation(graph, &flib_def);
+}
+
+std::unordered_map<string, string> GetClusters(const Graph& graph) {
+ std::unordered_map<string, string> ids;
+ for (Node* node : graph.nodes()) {
+ string cluster;
+ if (GetNodeAttr(node->def(), kXlaClusterAttr, &cluster).ok()) {
+ CHECK(!cluster.empty());
+ ids[node->name()] = cluster;
+ }
+ }
+ return ids;
+}
+
+TEST(XlaCompilationTest, Chains) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
+ Node* d =
+ ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D"));
+ Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E"));
+ ops::UnaryOp("Relu", e, builder.opts().WithName("F"));
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+ EXPECT_EQ(4, clusters.size());
+ EXPECT_EQ(clusters["B"], clusters["C"]);
+ EXPECT_EQ(clusters["E"], clusters["F"]);
+ EXPECT_NE(clusters["B"], clusters["E"]);
+ EXPECT_TRUE(clusters.find("A") == clusters.cend());
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST(XlaCompilationTest, UncompilableCycles) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b =
+ ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B"));
+ ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST(XlaCompilationTest, CompilableCycles) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_EQ(3, clusters.size());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+ EXPECT_EQ(clusters["A"], clusters["C"]);
+}
+
+TEST(XlaCompilationTest, UnsupportedTypes) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp(
+ "Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_COMPLEX64)
+ .WithAttr("value", Tensor(DT_COMPLEX64, TensorShape())));
+ Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B"));
+ ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C"));
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+ EXPECT_TRUE(clusters.empty());
+}
+
+TEST(XlaCompilationTest, ConcatWithConstArg) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ Tensor t(DT_INT32, TensorShape());
+ t.scalar<int32>()() = 0;
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* dim = ops::SourceOp("Const", builder.opts()
+ .WithName("Dim")
+ .WithAttr("dtype", DT_INT32)
+ .WithAttr("value", t));
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", t));
+
+ NodeBuilder concat_builder("Concat", "Concat",
+ builder.opts().op_registry());
+ concat_builder.Input(dim).Input({a, a}).Attr("N", 2);
+ builder.opts().FinalizeBuilder(&concat_builder);
+
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+ EXPECT_EQ(3, clusters.size()); // Everything should be compiled.
+}
+
+TEST(XlaCompilationTest, FunctionCalls) {
+ FunctionDefLibrary flib;
+ *flib.add_function() = FunctionDefHelper::Define(
+ "CompilableFn", {"n_a:float", "n_b:float"}, {"n_c:float"}, {},
+ {{{"n_c"}, "Add", {"n_a", "n_b"}, {{"T", DT_FLOAT}}}});
+ *flib.add_function() =
+ FunctionDefHelper::Define("UncompilableFn", {"n_a:float"}, {"n_c:float"},
+ {}, {{{"n_c"}, "UncompilableUnary", {"n_a"}}});
+ FunctionLibraryDefinition flib_def(OpRegistry::Global(), flib);
+
+ std::unique_ptr<Graph> graph(new Graph(&flib_def));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately, &flib_def);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+ Node* b = ops::BinaryOp("CompilableFn", a, a, builder.opts().WithName("B"));
+ Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
+ ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D"));
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph, &flib_def);
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_EQ(2, clusters.size());
+ EXPECT_FALSE(clusters["B"].empty());
+ EXPECT_EQ(clusters["B"], clusters["C"]);
+ EXPECT_TRUE(clusters.find("A") == clusters.cend());
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+// Metadata-only operators such as Shape/Rank/Size may not be the root of a
+// cluster. This is partially to work around b/26800664, and partially because
+// we should probably prefer to compile metadata operators with their producers
+// wherever possible, rather than their consumers.
+TEST(XlaCompilationTest, MetadataOpsDontStartClusters) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+    // While all of the following ops are notionally compilable, none is
+    // permitted to start a cluster. So nothing should be compiled.
+ Node* b = ops::UnaryOp("Shape", a, builder.opts().WithName("B"));
+ Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C"));
+ Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D"));
+ ops::UnaryOp("Shape", d, builder.opts().WithName("C"));
+ builder.ToGraph(graph.get());
+ }
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+ EXPECT_EQ(0, clusters.size()); // Nothing should be compiled.
+}
+
+static Status GradForUnaryCwise(FunctionDef* g,
+ std::vector<FunctionDefHelper::Node> nodes) {
+ for (auto& n : nodes) {
+ if (n.attr.empty()) {
+ n.attr = {{"T", DT_FLOAT}};
+ }
+ }
+ *g = FunctionDefHelper::Define(
+ // Arg defs
+ {"x: float", "dy: float"},
+ // Ret val defs
+ {"dx: float"},
+ // Attr defs
+ {},
+ // Nodes
+ nodes);
+ return Status::OK();
+}
+
+// A gradient containing only supported operators
+Status SupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Tanh", {"x"}},
+ {{"y2"}, "Square", {"y"}, {}, {"dy"}},
+ FunctionDefHelper::Const("one", 1.0f),
+ {{"a"}, "Sub", {"one", "y2"}},
+ {{"dx"}, "Mul", {"dy", "a"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Supported", SupportedGrad);
+
+// A gradient containing an unsupported operator.
+Status UnsupportedGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForUnaryCwise(g, {
+ {{"y"}, "Tanh", {"x"}},
+ {{"y2"}, "UncompilableUnary", {"y"}, {}, {"dy"}},
+ FunctionDefHelper::Const("one", 1.0f),
+ {{"a"}, "Sub", {"one", "y2"}},
+ {{"dx"}, "Mul", {"dy", "a"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Unsupported", UnsupportedGrad);
+
+TEST(XlaCompilationTest, SymbolicGradients) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a =
+ ops::SourceOp("UncompilableNullary", builder.opts().WithName("A"));
+
+ // Builds a Symbolic gradient for Supported
+ NodeBuilder b_builder("B", "SymbolicGradient",
+ builder.opts().op_registry());
+ NameAttrList b_name_attr;
+ b_name_attr.set_name("Supported");
+ b_builder.Attr("f", b_name_attr);
+ b_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
+ b_builder.Attr("Tout", {DT_FLOAT});
+ b_builder.Input({a, a});
+ Node* b = builder.opts().FinalizeBuilder(&b_builder);
+
+ Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C"));
+
+ // Builds a Symbolic gradient for Unsupported
+ NodeBuilder d_builder("D", "SymbolicGradient",
+ builder.opts().op_registry());
+ NameAttrList d_name_attr;
+ d_name_attr.set_name("Unsupported");
+ d_builder.Attr("f", d_name_attr);
+ d_builder.Attr("Tin", {DT_FLOAT, DT_FLOAT});
+ d_builder.Attr("Tout", {DT_FLOAT});
+ d_builder.Input({c, c});
+ builder.opts().FinalizeBuilder(&d_builder);
+
+ builder.ToGraph(graph.get());
+ }
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_EQ(2, clusters.size());
+ EXPECT_FALSE(clusters["B"].empty());
+ EXPECT_EQ(clusters["B"], clusters["C"]);
+ EXPECT_TRUE(clusters.find("A") == clusters.cend());
+ EXPECT_TRUE(clusters.find("D") == clusters.cend());
+}
+
+TEST(XlaCompilationTest, Loops) {
+ // Regression test for b/32350199, where the autoclustering code introduced a
+ // deadlock in a graph containing a while loop.
+ Scope root = Scope::NewRootScope().ExitOnError();
+ auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT);
+ auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT);
+ auto c = ops::Add(root.WithOpName("C"), a, b);
+ auto enter = ops::Enter(root, c, "aframe");
+ auto next_iter = ops::NextIteration(root, enter);
+ auto exit = ops::Exit(root, next_iter);
+ auto d = ops::Add(root.WithOpName("D"), c, exit);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ root.ToGraph(graph.get());
+
+ MarkForCompilation(&graph);
+ auto clusters = GetClusters(*graph);
+
+ // Nothing should be compiled. In particular, 'd' and 'c' must not be
+ // compiled.
+ EXPECT_EQ(0, clusters.size());
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/parallel_check_op.cc b/tensorflow/compiler/jit/parallel_check_op.cc
new file mode 100644
index 0000000000..d07da46ca0
--- /dev/null
+++ b/tensorflow/compiler/jit/parallel_check_op.cc
@@ -0,0 +1,154 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/legacy_flags/parallel_check_op_flags.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace {
+
+REGISTER_OP("ParallelCheck")
+ .Attr("T: list(type) >= 0")
+ .Input("expected: T")
+ .Input("actual: T")
+ .Output("result: T")
+ .Doc(R"doc(
+Op that compares two sets of inputs for near-identity, and propagates the first.
+Any inequality is logged to the ERROR log.
+)doc");
+
+// Takes 2*N input tensors and outputs the first N inputs unchanged.
+// Logs an error if input tensors i and i + N differ (beyond tolerance)
+// at any position.
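+// For example, with T = {float, float} the op takes inputs
+// (expected_0, expected_1, actual_0, actual_1) and returns
+// (expected_0, expected_1) unchanged.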
+class ParallelCheckOp : public OpKernel {
+ public:
+ explicit ParallelCheckOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ template <typename T>
+ int CompareTensors(DataType dtype, const char* v0, const char* v1,
+ int64 num_elts, int input_idx) {
+ int failed = 0;
+ const T* p0 = reinterpret_cast<const T*>(v0);
+ const T* p1 = reinterpret_cast<const T*>(v1);
+ double rtol;
+ legacy_flags::ParallelCheckOpFlags* flags =
+ legacy_flags::GetParallelCheckOpFlags();
+ if (!tensorflow::strings::safe_strtod(flags->parallel_check_rtol.c_str(),
+ &rtol)) {
+ LOG(ERROR) << "can't convert parallel_check_rtol "
+ << flags->parallel_check_rtol << " to double";
+ }
+ double atol;
+ if (!tensorflow::strings::safe_strtod(flags->parallel_check_atol.c_str(),
+ &atol)) {
+ LOG(ERROR) << "can't convert parallel_check_atol "
+ << flags->parallel_check_atol << " to double";
+ }
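+    // An element passes if p0[i] == p1[i], or, for float and double types, if
+    // |p0[i] - p1[i]| <= max(atol, rtol * max(|p0[i]|, |p1[i]|)).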
+ for (int i = 0; i < num_elts; ++i) {
+ bool ok = (p0[i] == p1[i]);
+ VLOG(2) << "output " << input_idx << " element " << i << ": " << p0[i];
+ if (!ok) {
+ if (std::is_same<T, float>::value || std::is_same<T, double>::value) {
+ float tolerance =
+ std::max(atol, std::max(fabs(rtol * p0[i]), fabs(rtol * p1[i])));
+ T diff = p0[i] - p1[i];
+ if (diff < 0) diff = 0 - diff;
+ ok = (diff <= tolerance);
+ }
+ if (ok) continue;
+ LOG(ERROR) << "Op " << def().name() << " fails equality at output "
+ << input_idx << " type " << DataTypeString(dtype)
+ << " element " << i << ": std_val=" << p0[i]
+ << " test_val=" << p1[i] << " diff=" << (p0[i] - p1[i]);
+ if (++failed > 10) break;
+ }
+ }
+ return failed;
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ VLOG(1) << "Compute " << def().name();
+ const int num_pairs = ctx->num_inputs() / 2;
+ for (int i = 0; i < num_pairs; ++i) {
+ CHECK_EQ(ctx->input_dtype(i), ctx->input_dtype(i + num_pairs));
+ Tensor t0 = ctx->input(i);
+ Tensor t1 = ctx->input(i + num_pairs);
+ int64 num_elts = t0.NumElements();
+ CHECK_EQ(num_elts, t1.NumElements());
+
+ // Compare inputs elementwise for near-exact equality.
+ const char* v0 = t0.tensor_data().data();
+ const char* v1 = t1.tensor_data().data();
+ int failed = 0;
+ switch (ctx->input_dtype(i)) {
+ case DT_INT32:
+ failed =
+ CompareTensors<int32>(ctx->input_dtype(i), v0, v1, num_elts, i);
+ break;
+ case DT_INT64:
+ failed =
+ CompareTensors<int64>(ctx->input_dtype(i), v0, v1, num_elts, i);
+ break;
+ case DT_FLOAT:
+ failed =
+ CompareTensors<float>(ctx->input_dtype(i), v0, v1, num_elts, i);
+ break;
+ case DT_DOUBLE:
+ failed =
+ CompareTensors<double>(ctx->input_dtype(i), v0, v1, num_elts, i);
+ break;
+ case DT_BOOL:
+ failed =
+ CompareTensors<bool>(ctx->input_dtype(i), v0, v1, num_elts, i);
+ break;
+ default:
+ LOG(FATAL) << "unimpl: " << ctx->input_dtype(i);
+ }
+ if (failed > 0) {
+ LOG(ERROR) << "check failed for " << def().name() << " output " << i
+ << " num_elts: " << num_elts;
+ legacy_flags::ParallelCheckOpFlags* flags =
+ legacy_flags::GetParallelCheckOpFlags();
+ if (flags->parallel_check_failfast) {
+ LOG(QFATAL) << "failfast on first parallel-check failure";
+ }
+ } else {
+ VLOG(1) << "check passed for " << def().name() << " output " << i
+ << " num_elts: " << num_elts;
+ }
+
+ // Propagate the std value.
+ if (IsRefType(ctx->input_dtype(i))) {
+ ctx->forward_ref_input_to_ref_output(i, i);
+ } else {
+ ctx->set_output(i, ctx->input(i));
+ }
+ }
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ParallelCheckOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelCheck").Device(DEVICE_CPU),
+ ParallelCheckOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
new file mode 100644
index 0000000000..4644121173
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+XlaCompilationCache::XlaCompilationCache(const XlaCompiler::Options& options)
+ : compiler_(options) {}
+
+XlaCompilationCache::~XlaCompilationCache() = default;
+
+string XlaCompilationCache::DebugString() {
+ return "XLA JIT compilation cache";
+}
+
+// Returns a human-readable string for `sig` that encodes the argument types,
+// shapes, and constant values of the signature.
+string XlaCompilationCache::SignatureDebugString(const Signature& sig) {
+ string result = sig.name;
+ for (const auto& a : sig.arg_types) {
+ strings::StrAppend(&result, ",", DataTypeString(a.first),
+ a.second.DebugString());
+ }
+
+ for (const auto& v : sig.arg_values) {
+ strings::StrAppend(&result, "; ", v.first, ":", v.second.DebugString());
+ }
+ return result;
+}
+
+bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
+ if (name != other.name) return false;
+ if (arg_types != other.arg_types) return false;
+
+ if (arg_values.size() != other.arg_values.size()) return false;
+ for (int i = 0; i < arg_values.size(); ++i) {
+ if (arg_values[i].first != other.arg_values[i].first ||
+ arg_values[i].second.tensor_data() !=
+ other.arg_values[i].second.tensor_data()) {
+ return false;
+ }
+ }
+ return true;
+}
+
+uint64 XlaCompilationCache::Signature::Hash::operator()(
+ const XlaCompilationCache::Signature& signature) const {
+ uint64 h = std::hash<string>()(signature.name);
+ for (const auto& arg : signature.arg_types) {
+ h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
+ h = Hash64Combine(h, std::hash<int>()(arg.second.dims()));
+ for (int dim : arg.second.dim_sizes()) {
+ h = Hash64Combine(h, std::hash<int>()(dim));
+ }
+ }
+ for (const auto& arg : signature.arg_values) {
+ h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
+ h = Hash64Combine(h, Hash64(arg.second.tensor_data().data(),
+ arg.second.tensor_data().size()));
+ }
+ return h;
+}
+
+namespace {
+
+// Builds a XlaCompiler::Argument vector from the arguments to the _XlaLaunch
+// op. The first `num_constant_args` arguments must be host-memory Tensors.
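+// Constant arguments (and empty tensors) get parameter == -1 and carry their
+// value in constant_value; the remaining arguments are numbered consecutively
+// as XLA parameters.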
+std::vector<XlaCompiler::Argument> BuildArguments(int num_constant_args,
+ OpKernelContext* ctx) {
+ std::vector<XlaCompiler::Argument> args(ctx->num_inputs());
+ int parameter_num = 0;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ args[i].type = ctx->input(i).dtype();
+ args[i].shape = ctx->input(i).shape();
+ if (i < num_constant_args || ctx->input(i).NumElements() == 0) {
+ args[i].parameter = -1;
+ args[i].constant_value = ctx->input(i);
+ } else {
+ args[i].parameter = parameter_num;
+ ++parameter_num;
+ }
+ }
+ return args;
+}
+
+} // namespace
+
+Status XlaCompilationCache::Compile(
+ const NameAttrList& function, int num_constant_args, OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult** compilation_result,
+ xla::LocalExecutable** executable) {
+ VLOG(1) << "XlaCompilationCache::Compile " << DebugString();
+
+ if (VLOG_IS_ON(2)) {
+ std::vector<string> argshapes;
+ VLOG(2) << "num_inputs = " << ctx->num_inputs()
+ << " num_constant_args= " << num_constant_args;
+ for (int i = 0; i < ctx->num_inputs(); i++) {
+ TensorShape shape = ctx->input(i).shape();
+ VLOG(2) << i << ": dtype=" << ctx->input_dtype(i)
+ << " present=" << ctx->has_input(i)
+ << " shape=" << shape.DebugString();
+ argshapes.push_back(shape.DebugString());
+ }
+ VLOG(2) << "num_outputs = " << ctx->num_outputs();
+ for (int i = 0; i < ctx->num_outputs(); i++) {
+ VLOG(2) << i << ": dtype=" << ctx->expected_output_dtype(i);
+ }
+ }
+ Signature signature;
+ signature.name = Canonicalize(function.name(), function.attr());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ signature.arg_types.emplace_back(ctx->input_dtype(i),
+ ctx->input(i).shape());
+ if (i < num_constant_args) {
+ signature.arg_values.emplace_back(i, ctx->input(i));
+ }
+ }
+
+ VLOG(2) << "Signature: " << SignatureDebugString(signature);
+ // The outer lock protects the existence of the cache entry. It does not
+ // protect the contents of the cache entry.
+ Entry* entry;
+ {
+ mutex_lock lock(mu_);
+ // Find or create a cache entry.
+ std::unique_ptr<Entry>& e = cache_[signature];
+ if (!e) {
+ e.reset(new Entry);
+ }
+ entry = e.get();
+ }
+
+ // Acquire the cache entry lock and compile, if necessary.
+ // TODO(phawkins): this locking will need to be restructured when we implement
+ // cache eviction.
+ mutex_lock entry_lock(entry->mu);
+ if (!entry->compiled) {
+    // Do the actual JIT compilation without holding the cache lock mu_
+    // (compilation can take a long time).
+ std::vector<XlaCompiler::Argument> args =
+ BuildArguments(num_constant_args, ctx);
+
+ std::unique_ptr<FunctionLibraryRuntime> flr(NewFunctionLibraryRuntime(
+ compiler_.device_mgr(), ctx->env(), compiler_.device(),
+ TF_GRAPH_DEF_VERSION,
+ ctx->function_library()->GetFunctionLibraryDefinition(),
+ OptimizerOptions(), nullptr /* custom_kernel_creator */));
+
+ entry->compiled = true;
+ entry->compilation_status = compiler_.CompileFunction(
+ flr.get(), function, args, &entry->compilation_result);
+ }
+ *compilation_result = &entry->compilation_result;
+ if (entry->compilation_status.ok() && executable) {
+ if (entry->executable == nullptr &&
+ !entry->compilation_result.computation.IsNull()) {
+ entry->compilation_status = compiler_.BuildExecutable(
+ entry->compilation_result, &entry->executable);
+ }
+ *executable = entry->executable.get();
+ }
+
+ Status status = entry->compilation_status;
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
new file mode 100644
index 0000000000..44d76db0fd
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// The XlaCompilationCache class caches the results of the XlaCompiler class,
+// which converts a Tensorflow graph into a compiled XLA compilation.
+//
+// Since XLA computations must have static shapes, the cache generates a new
+// XLA computation for each new set of input shapes.
+//
+// Currently no cache eviction policy is implemented and the cache grows without
+// bound.
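+//
+// Rough usage sketch (the caller-side names below are illustrative only):
+//
+//   const XlaCompiler::CompilationResult* result;
+//   xla::LocalExecutable* executable;
+//   TF_RETURN_IF_ERROR(cache->Compile(function, num_constant_args, ctx,
+//                                     &result, &executable));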
+class XlaCompilationCache : public ResourceBase {
+ public:
+ explicit XlaCompilationCache(const XlaCompiler::Options& options);
+ ~XlaCompilationCache() override;
+
+ // Compiles a function into a XlaCompiler::CompilationResult that can be used
+ // to execute an XLA Computation. `compilation_result` must be non-null.
+ // If `executable` is non-null, also builds an xla::LocalExecutable and sets
+  // `executable` to point to it. The resulting executable pointer may be null
+  // if the computation has no non-constant outputs.
+ // Compilation results are cached.
+ Status Compile(const NameAttrList& function, int num_constant_args,
+ OpKernelContext* ctx,
+ const XlaCompiler::CompilationResult** compilation_result,
+ xla::LocalExecutable** executable);
+
+ xla::Client* client() const { return compiler_.client(); }
+
+ string DebugString() override;
+
+ private:
+ XlaCompiler compiler_;
+ std::unique_ptr<FunctionLibraryRuntime> function_library_runtime_;
+
+ // Describes the types, shapes and any compile-time constant arguments
+ // to a kernel.
+ struct Signature {
+ string name;
+
+ std::vector<std::pair<DataType, TensorShape>> arg_types;
+
+ // List of (argument #, value) pairs for arguments whose values are
+ // part of the JIT signature, and that are therefore constants in any given
+ // JIT compilation. Tensors must be in host memory.
+ std::vector<std::pair<int, Tensor>> arg_values;
+
+ bool operator==(const Signature& other) const;
+
+ struct Hash {
+ uint64 operator()(const Signature& signature) const;
+ };
+ };
+ static string SignatureDebugString(const Signature& sig);
+
+ // The value associated with a cache entry.
+ struct Entry {
+ mutex mu;
+
+ // Have we tried compiling this entry?
+ bool compiled = false;
+
+ // Did compilation succeed?
+ Status compilation_status GUARDED_BY(mu);
+
+ // Output of the XlaCompiler.
+ XlaCompiler::CompilationResult compilation_result GUARDED_BY(mu);
+
+ // The XLA executable compiled from <computation>. May be null if no
+ // executable has been built.
+ std::unique_ptr<xla::LocalExecutable> executable GUARDED_BY(mu);
+ };
+
+ mutex mu_;
+ std::unordered_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
+ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
new file mode 100644
index 0000000000..92784a5358
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Registers the XLA_CPU device, which is an XlaDevice instantiation that runs
+// operators using XLA via the XLA "Host" (CPU) backend.
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+const char* const DEVICE_XLA_CPU = "XLA_CPU";
+
+class XlaCpuDeviceFactory : public DeviceFactory {
+ public:
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+};
+
+Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ static XlaDeviceOpRegistrations* registrations =
+ RegisterXlaDeviceKernels(DEVICE_XLA_CPU, DEVICE_CPU_XLA_JIT);
+ (void)registrations;
+
+ std::unique_ptr<XlaDevice> device;
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+ DEVICE_CPU_XLA_JIT, options, name_prefix,
+ &device));
+ devices->push_back(device.release());
+ return Status::OK();
+}
+
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
+
+// Kernel registrations
+
+constexpr std::array<DataType, 5> kAllXlaCpuTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaDeviceLaunchOp, kAllXlaCpuTypes);
+REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
new file mode 100644
index 0000000000..a3fb5b4106
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -0,0 +1,219 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device.h"
+
+#include <stdlib.h>
+#include <unordered_set>
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/public/version.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace tensorflow {
+
+/* static */ Status XlaDevice::Create(
+ const string& platform_name, const string& device_name, int device_ordinal,
+ const string& jit_device_name, const SessionOptions& options,
+ const string& name_prefix, std::unique_ptr<XlaDevice>* device) {
+ VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
+ << device_ordinal;
+
+ // These are no-ops if they have already been done previously for
+ // this device_name/jit_device_name pair.
+ XlaOpRegistry::RegisterJitKernels();
+ XlaOpRegistry::RegisterJitDevice(device_name, jit_device_name,
+ /*requires_jit=*/true);
+
+ auto platform = perftools::gputools::MultiPlatformManager::PlatformWithName(
+ platform_name);
+ if (!platform.ok()) {
+ return StreamExecutorUtil::ConvertStatus(platform.status());
+ }
+
+ const DeviceAttributes attrs = Device::BuildDeviceAttributes(
+ strings::StrCat(name_prefix, "/device:", device_name, ":",
+ device_ordinal),
+ DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
+ strings::StrCat("device: ", device_name, " device"));
+
+ static Allocator* allocator = new XlaDeviceAllocator;
+ device->reset(new XlaDevice(options, attrs, device_ordinal,
+ DeviceType(jit_device_name),
+ platform.ValueOrDie(), allocator));
+ return Status::OK();
+}
+
+XlaDevice::Metadata::Metadata(int device_ordinal,
+ perftools::gputools::Platform* platform,
+ const DeviceType& device_type)
+ : device_ordinal_(device_ordinal),
+ device_type_(device_type),
+ platform_(platform) {}
+
+int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
+
+perftools::gputools::Platform* XlaDevice::Metadata::platform() const {
+ return platform_;
+}
+
+XlaDevice::Metadata::~Metadata() {}
+
+xla::LocalClient* XlaDevice::Metadata::client() const {
+ auto client = xla::ClientLibrary::GetOrCreateLocalClient(platform_);
+ return client.ValueOrDie();
+}
+
+const DeviceType& XlaDevice::Metadata::jit_device_type() const {
+ return device_type_;
+}
+
+string XlaDevice::Metadata::DebugString() { return "XLA device metadata"; }
+
+XlaDevice::XlaDevice(const SessionOptions& options,
+ const DeviceAttributes& attrs, int device_ordinal,
+ const DeviceType& jit_device_name,
+ perftools::gputools::Platform* platform,
+ Allocator* xla_allocator)
+ : LocalDevice(options, attrs, xla_allocator),
+ device_ordinal_(device_ordinal),
+ jit_device_name_(jit_device_name),
+ xla_allocator_(xla_allocator),
+ platform_(platform) {
+ // Store the platform in the resource manager so Ops can retrieve it
+ // e.g., to lazily create a XlaCompilationCache object.
+ TF_CHECK_OK(resource_manager()->Create<Metadata>(
+ resource_manager()->default_container(), "xla_metadata",
+ new Metadata(device_ordinal_, platform_, jit_device_name_)));
+}
+XlaDevice::~XlaDevice() {}
+
+xla::LocalClient* XlaDevice::client() const {
+ // We lazily create the client because the platform commits to the
+ // details of the host hardware when the client is created, so we
+ // don't want to do it until we get a chance to hook the platform up
+ // to a simulator.
+
+ // For now GetOrCreateLocalClient always returns success when passed
+ // a non-null platform. If that changes we may have to plumb in some
+ // way to pass Status back.
+ return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
+}
+
+Allocator* XlaDevice::GetAllocator(AllocatorAttributes attr) {
+ if (attr.on_host()) {
+ return cpu_allocator();
+ } else {
+ return xla_allocator_;
+ }
+}
+
+Status XlaDevice::FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) {
+ VLOG(1) << "XlaDevice::FillContextMap";
+ device_context_map->resize(graph->num_node_ids());
+ XlaDeviceContext* ctx = new XlaDeviceContext(client());
+ for (Node* n : graph->nodes()) {
+ VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
+ ctx->Ref();
+ (*device_context_map)[n->id()] = ctx;
+ }
+ return Status::OK();
+}
+
+void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
+ VLOG(1) << "XlaDevice::Compute " << op_kernel->name() << ":"
+ << op_kernel->type_string();
+ op_kernel->Compute(context);
+}
+
+void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) {
+ VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":"
+ << op_kernel->type_string();
+ op_kernel->ComputeAsync(context, done);
+}
+
+Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ VLOG(1) << "XlaDevice::MakeTensorFromProto";
+
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+
+ Status status;
+ if (alloc_attrs.on_host()) {
+ *tensor = parsed;
+ } else {
+ Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
+ Notification n;
+ XlaTransferManager manager(client());
+ manager.CopyCPUTensorToDevice(&parsed, this, &copy,
+ [&n, &status](const Status& s) {
+ status = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ *tensor = copy;
+ }
+ VLOG(2) << "Allocated tensor at " << DMAHelper::base(tensor);
+ return status;
+}
+
+XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
+ const char* jit_device) {
+ XlaDeviceOpRegistrations* registrations = new XlaDeviceOpRegistrations;
+ auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* {
+ return new XlaDeviceDummyOp(context);
+ };
+ for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(jit_device)) {
+ KernelDef* def = new KernelDef(*jit_def);
+ def->set_device_type(device);
+ registrations->op_kernel_registrars.emplace_back(
+ new kernel_factory::OpKernelRegistrar(def, "XlaDeviceDummyOp",
+ dummy_factory));
+ }
+ return registrations;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
new file mode 100644
index 0000000000..3de14f3061
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The XlaDevice executes a TensorFlow graph using the XLA linear algebra
+// runtime.
+//
+// Operators assigned to an XlaDevice are compiled into XLA computations.
+// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
+// is managed by XLA.
+//
+// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
+// under different names (e.g., XLA_CPU or XLA_GPU).
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace tensorflow {
+
+class XlaDevice : public LocalDevice {
+ public:
+ // Wrapper class to store metadata about the XlaDevice in the
+  // resource manager, where it can be looked up, e.g., when lazily
+  // creating an XlaCompilationCache.
+ class Metadata : public ResourceBase {
+ public:
+ Metadata(int device_ordinal, perftools::gputools::Platform* platform,
+ const DeviceType& device_type);
+ ~Metadata() override;
+
+ // The index of the device on this host.
+ int device_ordinal() const;
+
+ perftools::gputools::Platform* platform() const;
+ xla::LocalClient* client() const;
+ const DeviceType& jit_device_type() const;
+
+ string DebugString() override;
+
+ private:
+ const int device_ordinal_;
+ const DeviceType device_type_;
+ perftools::gputools::Platform* platform_; // Not owned.
+ };
+
+ // Factory function. 'platform_name' is the name of the XLA platform.
+ // 'device_name' is the name of the Tensorflow device to create.
+ // 'jit_device_name' is the name of the corresponding JIT device.
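+  // Example (as used by the XLA_CPU device factory in xla_cpu_device.cc):
+  //   std::unique_ptr<XlaDevice> device;
+  //   TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+  //                                        DEVICE_CPU_XLA_JIT, options,
+  //                                        name_prefix, &device));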
+ static Status Create(const string& platform_name, const string& device_name,
+ int device_ordinal, const string& jit_device_name,
+ const SessionOptions& options, const string& name_prefix,
+ std::unique_ptr<XlaDevice>* device);
+
+ XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
+ int device_ordinal, const DeviceType& jit_device_name,
+ ::perftools::gputools::Platform* platform,
+ Allocator* xla_allocator);
+ ~XlaDevice() override;
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override;
+ void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
+ void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
+ AsyncOpKernel::DoneCallback done) override;
+ Status Sync() override { return Status::OK(); }
+
+ Status FillContextMap(const Graph* graph,
+ DeviceContextMap* device_context_map) override;
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ xla::LocalClient* client() const;
+
+ private:
+ // Which hardware device in the client's platform this XlaDevice controls.
+ const int device_ordinal_;
+ // The name of the device that is used to compile Ops for this XlaDevice.
+ const DeviceType& jit_device_name_;
+ Allocator* xla_allocator_; // Not owned.
+ ::perftools::gputools::Platform* platform_; // Not owned.
+};
+
+// Builds dummy OpKernel registrations on 'device' for the JIT operators
+// registered on 'jit_device'. Returns ownership of a XlaDeviceOpRegistrations
+// object that encapsulates the kernel registrations.
+struct XlaDeviceOpRegistrations {
+ std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
+ op_kernel_registrars;
+};
+XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
+ const char* jit_device);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
new file mode 100644
index 0000000000..250960d395
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -0,0 +1,181 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device_context.h"
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+
+namespace tensorflow {
+
+// The contents of tensors allocated by XlaDeviceAllocator.
+struct XlaGlobalData {
+ mutable mutex mu;
+ // May be nullptr if there is no xla::GlobalData backing this Tensor.
+ std::shared_ptr<xla::GlobalData> data GUARDED_BY(mu);
+};
+
+// The allocator used for Tensors assigned to the XLA device. The allocator
+// doesn't actually back Tensors with storage. Instead, each tensor contains
+// a XlaGlobalData that wraps XLA-managed storage.
+XlaDeviceAllocator::XlaDeviceAllocator() = default;
+XlaDeviceAllocator::~XlaDeviceAllocator() = default;
+
+string XlaDeviceAllocator::Name() { return "xla"; }
+
+void* XlaDeviceAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
+  // Regardless of the size requested, always allocate an XlaGlobalData.
+  // Respect the alignment request because there is alignment checking even
+  // for Tensors whose data is never accessed.
+ void* p = port::aligned_malloc(sizeof(XlaGlobalData), alignment);
+ VLOG(2) << "Allocated XLA device tensor " << p;
+ return new (p) XlaGlobalData();
+}
+
+void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
+ XlaGlobalData* global_data = reinterpret_cast<XlaGlobalData*>(ptr);
+ VLOG(2) << "Deallocated XLA device tensor " << ptr;
+ global_data->~XlaGlobalData();
+ port::aligned_free(ptr);
+}
+
+void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
+
+// Don't run any constructors or destructors for complex objects,
+// since there is no backing store for the tensor to run them
+// on. strings are the only complex objects currently stored in
+// Tensors. If others are added, this set of overrides must be
+// extended to include them.
+void XlaDeviceAllocator::RunStringCtor(string* p, size_t n) {}
+void XlaDeviceAllocator::RunStringDtor(string* p, size_t n) {}
+void XlaDeviceAllocator::RunResourceCtor(ResourceHandle* p, size_t n) {}
+void XlaDeviceAllocator::RunResourceDtor(ResourceHandle* p, size_t n) {}
+
+static const XlaGlobalData* CastTensorToXlaGlobalData(const Tensor& tensor) {
+ const XlaGlobalData* expression =
+ reinterpret_cast<const XlaGlobalData*>(tensor.tensor_data().data());
+ return expression;
+}
+
+static XlaGlobalData* CastTensorToXlaGlobalData(Tensor* tensor) {
+ const XlaGlobalData* expression =
+ reinterpret_cast<const XlaGlobalData*>(tensor->tensor_data().data());
+ return const_cast<XlaGlobalData*>(expression);
+}
+
+XlaTransferManager::XlaTransferManager(xla::Client* client) : client_(client) {}
+
+void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
+ Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ if (cpu_tensor->NumElements() > 0) {
+ VLOG(2) << "CopyCPUTensorToDevice "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << " " << reinterpret_cast<const void*>(
+ device_tensor->tensor_data().data())
+ << cpu_tensor->NumElements();
+ xla::Literal literal;
+ Status status = HostTensorToLiteral(*cpu_tensor, &literal);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ auto gd = client_->TransferToServer(literal);
+ if (!gd.ok()) {
+ done(gd.status());
+ return;
+ }
+ SetTensorGlobalData(
+ std::shared_ptr<xla::GlobalData>(std::move(gd.ValueOrDie())),
+ device_tensor);
+ } else {
+ VLOG(2) << "CopyCPUTensorToDevice empty tensor";
+ }
+ done(Status::OK());
+}
+
+void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ StringPiece tensor_name,
+ Device* device,
+ Tensor* cpu_tensor,
+ StatusCallback done) {
+ if (device_tensor->NumElements() > 0) {
+ VLOG(2) << "CopyDeviceTensorToCPU"
+ << reinterpret_cast<const void*>(
+ device_tensor->tensor_data().data())
+ << " "
+ << reinterpret_cast<const void*>(cpu_tensor->tensor_data().data())
+ << device_tensor->NumElements();
+ std::shared_ptr<xla::GlobalData> global_data =
+ GetTensorGlobalData(*device_tensor);
+
+ xla::Shape shape;
+ Status status =
+ TensorShapeToXLAShape(cpu_tensor->dtype(), cpu_tensor->shape(), &shape);
+ if (!status.ok()) {
+ done(status);
+ return;
+ }
+ auto result = client_->Transfer(*global_data, &shape);
+ if (!result.ok()) {
+ done(result.status());
+ return;
+ }
+ const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie());
+ void* dst_ptr = DMAHelper::base(cpu_tensor);
+ size_t total_bytes = cpu_tensor->TotalBytes();
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ } else {
+ VLOG(2) << "CopyDeviceTensorToCPU empty tensor";
+ }
+ done(Status::OK());
+}
+
+std::shared_ptr<xla::GlobalData> XlaTransferManager::GetTensorGlobalData(
+ const Tensor& tensor) {
+ const XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
+ mutex_lock lock(data->mu);
+ CHECK(data->data);
+ return data->data;
+}
+
+void XlaTransferManager::SetTensorGlobalData(
+ std::shared_ptr<xla::GlobalData> global_data, Tensor* tensor) {
+ XlaGlobalData* data = CastTensorToXlaGlobalData(tensor);
+ mutex_lock lock(data->mu);
+ data->data = std::move(global_data);
+}
+
+XlaDeviceContext::XlaDeviceContext(xla::Client* client) : manager_(client) {}
+
+void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
+ Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const {
+ manager_.CopyCPUTensorToDevice(cpu_tensor, device, device_tensor, done);
+}
+
+void XlaDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ StringPiece tensor_name,
+ Device* device, Tensor* cpu_tensor,
+ StatusCallback done) {
+ manager_.CopyDeviceTensorToCPU(device_tensor, tensor_name, device, cpu_tensor,
+ done);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
new file mode 100644
index 0000000000..8ab462b615
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// The allocator used for Tensors assigned to the XLA device. The allocator
+// doesn't actually back Tensors with storage. Instead, each tensor is a thin
+// wrapper around XLA-managed storage.
+class XlaDeviceAllocator : public Allocator {
+ public:
+ XlaDeviceAllocator();
+ ~XlaDeviceAllocator() override;
+
+ string Name() override;
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+ void DeallocateRaw(void* ptr) override;
+ void GetStats(AllocatorStats* stats) override;
+
+ private:
+ void RunStringCtor(string* p, size_t n) override;
+ void RunStringDtor(string* p, size_t n) override;
+ void RunResourceCtor(ResourceHandle* p, size_t n) override;
+ void RunResourceDtor(ResourceHandle* p, size_t n) override;
+};
+
+// Helper class for managing data transfers between host and XLA devices.
+class XlaTransferManager {
+ public:
+ explicit XlaTransferManager(xla::Client* client);
+
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor, StatusCallback done) const;
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ StringPiece tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done);
+
+ // Helper methods to get/set the xla::GlobalData backing a Tensor on the
+ // XlaDevice.
+ static std::shared_ptr<xla::GlobalData> GetTensorGlobalData(
+ const Tensor& tensor);
+ static void SetTensorGlobalData(std::shared_ptr<xla::GlobalData> global_data,
+ Tensor* tensor);
+
+ private:
+ xla::Client* client_;
+};
+
+// DeviceContext for operators assigned to XlaDevice devices. The
+// implementation must inherit from DeviceContext but otherwise just
+// wraps the methods in XlaTransferManager.
+class XlaDeviceContext : public DeviceContext {
+ public:
+ explicit XlaDeviceContext(xla::Client* client);
+
+ void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
+ Tensor* device_tensor,
+ StatusCallback done) const override;
+ void CopyDeviceTensorToCPU(const Tensor* device_tensor,
+ StringPiece tensor_name, Device* device,
+ Tensor* cpu_tensor, StatusCallback done) override;
+
+ private:
+ XlaTransferManager manager_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_CONTEXT_H_
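GetTensorGlobalData and SetTensorGlobalData in this header hand a std::shared_ptr<xla::GlobalData> back and forth under a mutex, so the launch op and the device context can share a single device allocation without copying it. Below is a minimal standalone model of that per-tensor slot; FakeGlobalData and GlobalDataSlot are illustrative names standing in for the XlaGlobalData storage reached via CastTensorToXlaGlobalData above, not the real types.

#include <cassert>
#include <memory>
#include <mutex>

// Illustrative stand-in for xla::GlobalData: an opaque device-side handle.
struct FakeGlobalData { int handle = 0; };

// Model of the mutex-guarded shared_ptr slot behind Get/SetTensorGlobalData.
struct GlobalDataSlot {
  std::mutex mu;
  std::shared_ptr<FakeGlobalData> data;
};

void Set(GlobalDataSlot* slot, std::shared_ptr<FakeGlobalData> gd) {
  std::lock_guard<std::mutex> lock(slot->mu);
  slot->data = std::move(gd);
}

std::shared_ptr<FakeGlobalData> Get(GlobalDataSlot* slot) {
  std::lock_guard<std::mutex> lock(slot->mu);
  assert(slot->data != nullptr);  // Mirrors the CHECK in GetTensorGlobalData.
  return slot->data;
}

int main() {
  GlobalDataSlot slot;
  Set(&slot, std::make_shared<FakeGlobalData>(FakeGlobalData{42}));
  return Get(&slot)->handle == 42 ? 0 : 1;
}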
diff --git a/tensorflow/compiler/jit/xla_device_launch_op.cc b/tensorflow/compiler/jit/xla_device_launch_op.cc
new file mode 100644
index 0000000000..becfbc1389
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_launch_op.cc
@@ -0,0 +1,171 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device_launch_op.h"
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_device_context.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+namespace {
+
+Status BuildCompilationCache(ResourceMgr* rm, XlaCompilationCache** compiler) {
+ XlaDevice::Metadata* metadata;
+ Status s = rm->Lookup<XlaDevice::Metadata>(rm->default_container(),
+ "xla_metadata", &metadata);
+ if (!s.ok()) {
+ return s;
+ }
+ core::ScopedUnref metadata_ref(metadata);
+ XlaCompiler::Options options;
+ options.device_type = metadata->jit_device_type();
+ options.client = metadata->client();
+ options.allow_cpu_custom_calls = false;
+ options.local_executable_has_hybrid_result = false;
+ *compiler = new XlaCompilationCache(options);
+ return Status::OK();
+}
+
+} // namespace
+
+XlaDeviceLaunchOp::XlaDeviceLaunchOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
+ function_ = *func;
+ VLOG(1) << "XlaDeviceLaunch created function="
+ << Canonicalize(function_.name(), function_.attr());
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
+ num_constant_args_ = constant_types.size();
+}
+
+void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaDeviceLaunch::Compute "
+ << Canonicalize(function_.name(), function_.attr());
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
+
+ XlaCompilationCache* compiler;
+ OP_REQUIRES_OK(ctx,
+ rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_compiler", &compiler,
+ [rm](XlaCompilationCache** compiler) {
+ return BuildCompilationCache(rm, compiler);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref compiler_ref(compiler);
+
+ const XlaCompiler::CompilationResult* kernel;
+ OP_REQUIRES_OK(
+ ctx,
+ compiler->Compile(function_, num_constant_args_, ctx, &kernel, nullptr));
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ OP_REQUIRES(ctx, ctx->num_outputs() == kernel->outputs.size(),
+ errors::Internal("Unexpected number of outputs"));
+
+ // Run the computation, if any. There might not be a computation if all
+ // outputs were compile-time constants.
+ std::vector<std::unique_ptr<xla::GlobalData>> outputs;
+ if (!kernel->computation.IsNull()) {
+ auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
+
+ // Convert argument tensors to xla::GlobalData pointers.
+ std::vector<std::shared_ptr<xla::GlobalData>> arg_handles(
+ kernel->xla_input_shapes.size());
+ std::vector<xla::GlobalData*> arg_ptrs(kernel->xla_input_shapes.size());
+ for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
+ int input_num = kernel->xla_input_shapes[i].first;
+ arg_handles[i] =
+ XlaTransferManager::GetTensorGlobalData(ctx->input(input_num));
+ arg_ptrs[i] = arg_handles[i].get();
+ }
+
+ // Execute the computation.
+ xla::ExecutionProfile profile;
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+ auto result = compiler->client()->Execute(
+ kernel->computation, arg_ptrs, &kernel->xla_output_shape, &profile);
+ auto elapsed = env->NowMicros() - start_time;
+ OP_REQUIRES(ctx, result.ok(), result.status());
+
+ VLOG(1) << "Elapsed time: " << elapsed << "us";
+ VLOG(1) << "ExecutionProfile: " << profile.DebugString();
+
+ if (xla::ShapeUtil::IsTuple(kernel->xla_output_shape)) {
+ auto outputs_or_error =
+ compiler->client()->DeconstructTuple(*result.ValueOrDie());
+ OP_REQUIRES(ctx, outputs_or_error.ok(), outputs_or_error.status());
+ outputs = outputs_or_error.ConsumeValueOrDie();
+ } else {
+ outputs.push_back(result.ConsumeValueOrDie());
+ }
+ }
+
+ XlaDeviceContext* device_context = ctx->op_device_context<XlaDeviceContext>();
+
+ // Copy XLA outputs to the operator's outputs.
+ int output_num = 0;
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ Tensor* output;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(i, kernel->outputs[i].shape, &output));
+ if (kernel->outputs[i].is_constant) {
+ // TODO(phawkins): mark constant _XlaLaunch outputs as HostMemory and
+ // remove the copy from this code.
+ Status status;
+ device_context->CopyCPUTensorToDevice(
+ &kernel->outputs[i].constant_value, nullptr, output,
+ [&status](const Status& s) { status = s; });
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ } else {
+ CHECK_LT(output_num, outputs.size());
+ XlaTransferManager::SetTensorGlobalData(
+ std::shared_ptr<xla::GlobalData>(std::move(outputs[output_num])),
+ output);
+ ++output_num;
+ }
+ }
+
+ VLOG(1) << "Done";
+}
+
+XlaDeviceLaunchOp::~XlaDeviceLaunchOp() {
+ VLOG(1) << "XlaDeviceLaunch destroyed";
+}
+
+} // namespace tensorflow
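XlaDeviceLaunchOp::Compute builds its XlaCompilationCache lazily through ResourceMgr::LookupOrCreate, so the first invocation pays the construction cost and later invocations share the cached resource (held alive during the call by core::ScopedUnref). A minimal standalone model of that lookup-or-create pattern follows; FakeCache and FakeResourceMgr are illustrative names, not the real ResourceMgr API.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Illustrative compilation-cache stand-in.
struct FakeCache { std::string device; };

// The first caller runs the factory; every later lookup under the same name
// reuses the stored result, so the expensive object is built at most once.
class FakeResourceMgr {
 public:
  std::shared_ptr<FakeCache> LookupOrCreate(
      const std::string& name,
      const std::function<std::shared_ptr<FakeCache>()>& factory) {
    auto it = resources_.find(name);
    if (it == resources_.end()) {
      it = resources_.emplace(name, factory()).first;
    }
    return it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<FakeCache>> resources_;
};

int main() {
  FakeResourceMgr rm;
  auto a = rm.LookupOrCreate("xla_compiler", [] {
    return std::make_shared<FakeCache>(FakeCache{"XLA_GPU"});
  });
  auto b = rm.LookupOrCreate("xla_compiler", [] {
    return std::make_shared<FakeCache>(FakeCache{"unused"});
  });
  std::cout << (a == b) << "\n";  // Prints 1: the factory only ran once.
  return 0;
}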
diff --git a/tensorflow/compiler/jit/xla_device_launch_op.h b/tensorflow/compiler/jit/xla_device_launch_op.h
new file mode 100644
index 0000000000..fbb9319b84
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_launch_op.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// The XlaDeviceLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaDeviceLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use an 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+class XlaDeviceLaunchOp : public OpKernel {
+ public:
+ explicit XlaDeviceLaunchOp(OpKernelConstruction* ctx);
+ ~XlaDeviceLaunchOp() override;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ NameAttrList function_;
+ int num_constant_args_;
+ Tensor dummy_tensor_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaDeviceLaunchOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/jit/xla_device_ops.cc b/tensorflow/compiler/jit/xla_device_ops.cc
new file mode 100644
index 0000000000..74c314c8ed
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_ops.cc
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+
+#include "tensorflow/compiler/jit/xla_device_context.h"
+
+namespace tensorflow {
+
+void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
+ const Tensor& rhs) {
+ std::shared_ptr<xla::GlobalData> gd =
+ XlaTransferManager::GetTensorGlobalData(rhs);
+ XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
+}
+
+XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
+  LOG(FATAL) << "Attempted to execute Op " << name() << " type "
+             << type_string()
+             << " on an XLA device. This should never happen.";
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
new file mode 100644
index 0000000000..1fcb515ddb
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Common kernel registrations for XLA devices.
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
+
+#include "tensorflow/compiler/jit/xla_device_launch_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/kernels/assign_op.h"
+#include "tensorflow/core/kernels/constant_op.h"
+#include "tensorflow/core/kernels/control_flow_ops.h"
+#include "tensorflow/core/kernels/identity_op.h"
+#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/core/kernels/sendrecv_ops.h"
+#include "tensorflow/core/kernels/variable_ops.h"
+
+namespace tensorflow {
+
+// Implementation of Assign for XLA devices.
+class XlaDeviceAssignOp : public AssignOp {
+ public:
+ using AssignOp::AssignOp;
+
+ void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
+};
+
+// Dummy OpKernel, used for kernels assigned to an XLA device that should be
+// compiled. Should never be called at runtime since such ops should be
+// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
+// operator on an XLA device but the compiler did not compile it.
+class XlaDeviceDummyOp : public OpKernel {
+ public:
+ explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+};
+
+#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_XlaLaunch").Device(DEVICE).HostMemory("constants"), KERNEL);
+
+#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES) \
+ REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE), SendOp); \
+ REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE), RecvOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \
+ REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \
+ ConstantOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
+ REGISTER_KERNEL_BUILDER(Name("Placeholder").Device(DEVICE), \
+ XlaDeviceDummyOp); \
+ \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \
+ VariableOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
+ TemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
+ .Device(DEVICE) \
+ .TypeConstraint("T", TYPES), \
+ DestroyTemporaryVariableOp); \
+ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
+ .Device(DEVICE) \
+ .TypeConstraint("dtype", TYPES) \
+ .HostMemory("is_initialized"), \
+ IsVariableInitializedOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \
+ XlaDeviceAssignOp); \
+ \
+ REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
+ ControlTriggerOp); \
+ REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
+ REGISTER_KERNEL_BUILDER(Name("Exit").Device(DEVICE), ExitOp); \
+ REGISTER_KERNEL_BUILDER(Name("NextIteration").Device(DEVICE), \
+ NextIterationOp); \
+ REGISTER_KERNEL_BUILDER(Name("Switch").Device(DEVICE).HostMemory("pred"), \
+ SwitchOp); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Merge").Device(DEVICE).HostMemory("value_index"), MergeOp); \
+ REGISTER_KERNEL_BUILDER(Name("LoopCond") \
+ .Device(DEVICE) \
+ .HostMemory("input") \
+ .HostMemory("output"), \
+ IdentityOp);
+
+// TODO(phawkins): do we really need Placeholder? Should it be a real
+// implementation of Placeholder?
+
+// TODO(b/32507444): the registrations for the control flow operators are
+// temporary and exist primarily to work around a bug in the graph partitioning
+// code.
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
new file mode 100644
index 0000000000..731ff7d673
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
+// operators using XLA via the XLA "CUDA" (GPU) backend.
+
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/jit/xla_device_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+const char* const DEVICE_XLA_GPU = "XLA_GPU";
+
+class XlaGpuDeviceFactory : public DeviceFactory {
+ public:
+ Status CreateDevices(const SessionOptions& options, const string& name_prefix,
+ std::vector<Device*>* devices) override;
+};
+
+Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
+ const string& name_prefix,
+ std::vector<Device*>* devices) {
+ static XlaDeviceOpRegistrations* registrations =
+ RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
+ (void)registrations;
+
+ std::unique_ptr<XlaDevice> device;
+ Status status =
+ XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
+ name_prefix, &device);
+ if (!status.ok()) {
+ // Treat failures as non-fatal; there might not be a GPU in the machine.
+ LOG(WARNING) << "Failed to create XLA_GPU device: " << status;
+ return Status::OK();
+ }
+ devices->push_back(device.release());
+ return Status::OK();
+}
+
+REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
+
+// Kernel registrations
+
+constexpr std::array<DataType, 5> kAllXlaGpuTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+
+REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaDeviceLaunchOp, kAllXlaGpuTypes);
+REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
+
+} // namespace tensorflow
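xla_gpu_device.cc makes the XLA_GPU device available simply by being linked into the binary: REGISTER_LOCAL_DEVICE_FACTORY expands, roughly, to a file-scope registration object, and the kernel macros from xla_device_ops.h do the same for the per-op kernels. The sketch below models only that static-registration idiom in standalone form; FactoryRegistry, Registrar, and FakeDevice are illustrative names, not the TensorFlow implementation.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct FakeDevice { std::string name; };

class FactoryRegistry {
 public:
  static FactoryRegistry& Global() {
    static FactoryRegistry registry;
    return registry;
  }
  void Register(const std::string& type,
                std::function<std::unique_ptr<FakeDevice>()> factory) {
    factories_[type] = std::move(factory);
  }
  std::unique_ptr<FakeDevice> Create(const std::string& type) {
    auto it = factories_.find(type);
    return it == factories_.end() ? nullptr : it->second();
  }

 private:
  std::map<std::string, std::function<std::unique_ptr<FakeDevice>()>> factories_;
};

// A file-scope object whose constructor adds a factory to the global registry,
// so merely linking the translation unit registers the device.
struct Registrar {
  Registrar(const std::string& type,
            std::function<std::unique_ptr<FakeDevice>()> factory) {
    FactoryRegistry::Global().Register(type, std::move(factory));
  }
};

// Analogous in spirit to REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, ...).
static Registrar xla_gpu_registrar("XLA_GPU", [] {
  return std::unique_ptr<FakeDevice>(new FakeDevice{"XLA_GPU:0"});
});

int main() {
  auto device = FactoryRegistry::Global().Create("XLA_GPU");
  std::cout << (device ? device->name : "not registered") << "\n";
  return 0;
}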
diff --git a/tensorflow/compiler/jit/xla_local_launch_op.cc b/tensorflow/compiler/jit/xla_local_launch_op.cc
new file mode 100644
index 0000000000..7945e057cf
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_local_launch_op.cc
@@ -0,0 +1,342 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/xla_local_launch_op.h"
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/util/stream_executor_util.h"
+
+namespace gpu = perftools::gputools;
+
+namespace tensorflow {
+
+REGISTER_OP("_XlaLaunch")
+ .Input("constants: Tconstants")
+ .Attr("Tconstants: list(type) >= 0")
+ .Input("args: Targs")
+ .Attr("Targs: list(type) >= 0")
+ .Output("results: Tresults")
+ .Attr("Tresults: list(type) >= 0")
+ .Attr("function: func")
+ .Doc("XLA Launch Op. For use by the XLA JIT only.");
+
+// Adapter class that wraps a Tensorflow allocator as an XLA allocator.
+class XlaAllocator : public xla::DeviceMemoryAllocator {
+ public:
+ XlaAllocator(const perftools::gputools::Platform* platform,
+ OpKernelContext* op_context);
+ ~XlaAllocator() override;
+ xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure = true) override;
+ Status Deallocate(int device_ordinal,
+ perftools::gputools::DeviceMemoryBase* mem) override;
+
+ // Makes 'tensor' a wrapper around the data buffer at 'ptr'. The buffer is
+ // interpreted as having data type 'dtype' and shape 'shape'.
+ Status MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer, DataType dtype,
+ const TensorShape& shape, Tensor* tensor) const;
+
+ private:
+ OpKernelContext* const op_context_;
+
+ // Map from pointer address to the owning Tensor; used by
+ // MakeTensorFromBuffer. Also used to automatically release Tensors when the
+ // allocator is freed.
+ std::unordered_map<void*, Tensor> tensors_;
+};
+
+XlaAllocator::XlaAllocator(const perftools::gputools::Platform* platform,
+ OpKernelContext* op_context)
+ : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {}
+
+XlaAllocator::~XlaAllocator() = default;
+
+xla::StatusOr<perftools::gputools::DeviceMemoryBase> XlaAllocator::Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) {
+ AllocatorAttributes allocator_attrs;
+ allocator_attrs.set_on_host(false);
+
+ AllocationAttributes allocation_attrs;
+ allocation_attrs.no_retry_on_failure = !retry_on_failure;
+
+ Tensor t;
+ Status status = op_context_->allocate_temp(
+ DT_UINT8, TensorShape({static_cast<int64>(size)}), &t, allocator_attrs,
+ allocation_attrs);
+ if (!status.ok()) {
+ VLOG(2) << "Allocation failed " << size;
+ return status;
+ }
+ void* data =
+ reinterpret_cast<void*>(const_cast<char*>(t.tensor_data().data()));
+ TF_RET_CHECK(data != nullptr);
+ tensors_[data] = t;
+ return perftools::gputools::DeviceMemoryBase(data, size);
+}
+
+Status XlaAllocator::Deallocate(int device_ordinal,
+ perftools::gputools::DeviceMemoryBase* mem) {
+ if (mem->opaque() != nullptr) {
+ if (tensors_.erase(mem->opaque()) == 0) {
+ return tensorflow::errors::InvalidArgument("Unknown tensor address");
+ }
+ }
+ return Status::OK();
+}
+
+Status XlaAllocator::MakeTensorFromBuffer(gpu::DeviceMemoryBase buffer,
+ DataType dtype,
+ const TensorShape& shape,
+ Tensor* out_tensor) const {
+ void* ptr = const_cast<void*>(buffer.opaque());
+ auto it = tensors_.find(ptr);
+ if (it == tensors_.end()) {
+ return errors::InvalidArgument("Unknown tensor address");
+ }
+ const Tensor& tensor = it->second;
+
+ int64 output_size = DataTypeSize(dtype) * shape.num_elements();
+ if (tensor.TotalBytes() == output_size) {
+ out_tensor->UnsafeCopyFromInternal(tensor, dtype, shape);
+ } else {
+ Tensor slice = tensor.Slice(0, output_size);
+ out_tensor->UnsafeCopyFromInternal(slice, dtype, shape);
+ }
+ return Status::OK();
+}
+
+XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), device_type_(ctx->device_type()) {
+ const NameAttrList* func;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("function", &func));
+ function_ = *func;
+ DataTypeVector constant_types;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tconstants", &constant_types));
+ num_constant_args_ = constant_types.size();
+}
+
+Status XlaLocalLaunchOp::BuildCompilationCache(XlaCompilationCache** compiler) {
+ gpu::Platform::Id platform_id;
+ if (device_type_ == DeviceType(DEVICE_CPU)) {
+ platform_id = gpu::host::kHostPlatformId;
+ } else if (device_type_ == DeviceType(DEVICE_GPU)) {
+ platform_id = gpu::cuda::kCudaPlatformId;
+ } else {
+ return errors::InvalidArgument("Unknown device type for local _XlaLaunch");
+ }
+
+ auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id);
+ if (!platform.ok()) {
+ return StreamExecutorUtil::ConvertStatus(platform.status());
+ }
+ auto client =
+ xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie());
+ if (!client.ok()) {
+ return client.status();
+ }
+ const string* compiler_device;
+ if (!XlaOpRegistry::GetJitDevice(device_type_.type(), &compiler_device,
+ /*requires_jit=*/nullptr)) {
+ return errors::InvalidArgument("No JIT device registered for ",
+ device_type_.type());
+ }
+ XlaCompiler::Options options;
+ options.device_type = DeviceType(*compiler_device);
+ options.client = client.ValueOrDie();
+ options.allow_cpu_custom_calls = (platform_id == gpu::host::kHostPlatformId);
+ options.local_executable_has_hybrid_result = true;
+ *compiler = new XlaCompilationCache(options);
+ return Status::OK();
+}
+
+void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
+ VLOG(1) << "XlaLocalLaunchOp::Compute "
+ << Canonicalize(function_.name(), function_.attr());
+ // We store information about the JIT-compiled XLA computation
+ // in the ResourceMgr.
+ ResourceMgr* rm = ctx->resource_manager();
+ OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
+
+ gpu::Stream* stream =
+ ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
+
+ XlaCompilationCache* compiler;
+ OP_REQUIRES_OK(ctx,
+ rm->LookupOrCreate<XlaCompilationCache>(
+ rm->default_container(), "xla_compiler", &compiler,
+ [this](XlaCompilationCache** compiler) {
+ return BuildCompilationCache(compiler);
+ }));
+ // Hold the reference to the JIT during evaluation. (We could probably
+ // free it sooner because the ResourceMgr will retain a reference, but
+ // this is more obviously correct.)
+ core::ScopedUnref compiler_ref(compiler);
+
+ xla::LocalClient* client = static_cast<xla::LocalClient*>(compiler->client());
+
+ const XlaCompiler::CompilationResult* kernel;
+ xla::LocalExecutable* executable;
+ OP_REQUIRES_OK(ctx,
+ compiler->Compile(function_, num_constant_args_, ctx, &kernel,
+ &executable));
+
+ VLOG(1) << "Executing XLA Computation...";
+
+ // Builds an XLA allocator for the device.
+ XlaAllocator xla_allocator(client->platform(), ctx);
+ XlaLocalRuntimeContext local_runtime_context;
+
+ std::unique_ptr<xla::ShapedBuffer> output;
+  bool output_is_tuple = false;
+ if (!kernel->computation.IsNull()) {
+ // Build xla::ShapedBuffers that point directly to the Tensor buffers.
+ std::vector<std::unique_ptr<xla::ShapedBuffer>> arg_buffers;
+ arg_buffers.reserve(kernel->xla_input_shapes.size() + 1);
+ arg_buffers.resize(kernel->xla_input_shapes.size());
+ std::vector<xla::ShapedBuffer*> arg_ptrs(arg_buffers.size());
+
+ // Pass remaining parameters.
+ for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
+ int arg_num = kernel->xla_input_shapes[i].first;
+ const xla::Shape& shape = kernel->xla_input_shapes[i].second;
+ gpu::DeviceMemoryBase dmem(
+ const_cast<char*>(ctx->input(arg_num).tensor_data().data()),
+ ctx->input(arg_num).tensor_data().size());
+
+ arg_buffers[i] =
+ xla::ShapedBuffer::MakeArrayShapedBuffer(
+ shape, client->platform(), client->default_device_ordinal(), dmem)
+ .ConsumeValueOrDie();
+ arg_ptrs[i] = arg_buffers[i].get();
+ }
+
+ // Make the final parameter point at local_runtime_context.
+ if (kernel->requires_runtime_context) {
+ gpu::DeviceMemoryBase local_runtime_context_dmem(
+ &local_runtime_context, sizeof(local_runtime_context));
+ arg_buffers.push_back(
+ xla::ShapedBuffer::MakeArrayShapedBuffer(
+ xla::ShapeUtil::MakeOpaqueShape(), client->platform(),
+ client->default_device_ordinal(), local_runtime_context_dmem)
+ .ConsumeValueOrDie());
+ arg_ptrs.push_back(arg_buffers.back().get());
+ }
+
+ // Execute the computation.
+ VLOG(2) << "Executing computation.";
+ xla::ExecutableRunOptions run_options;
+ run_options.set_stream(stream);
+ run_options.set_allocator(&xla_allocator);
+ run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
+ Env* env = Env::Default();
+ auto start_time = env->NowMicros();
+ auto run_result = executable->Run(arg_ptrs, run_options);
+ OP_REQUIRES(ctx, run_result.ok(), run_result.status());
+
+ if (local_runtime_context.error) {
+ ctx->CtxFailure(errors::InvalidArgument(
+ "Compiled kernel returned error: ", local_runtime_context.error_msg));
+ return;
+ }
+
+ output = std::move(run_result.ValueOrDie());
+ auto elapsed = env->NowMicros() - start_time;
+ VLOG(2) << "Elapsed time: " << elapsed << "us";
+
+ // Computation output should always be a tuple.
+ if (VLOG_IS_ON(2)) {
+ VLOG(2) << "Result tuple shape: " << output->shape().DebugString();
+ }
+ output_is_tuple = xla::ShapeUtil::IsTuple(output->shape());
+ }
+ CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
+
+ // Copy XLA results to the OpOutputList.
+ int output_num = 0;
+ for (int i = 0; i < ctx->num_outputs(); ++i) {
+ if (kernel->outputs[i].is_constant) {
+ // Output is a constant
+ const Tensor& const_tensor = kernel->outputs[i].constant_value;
+ const size_t total_bytes = const_tensor.TotalBytes();
+ if (stream && total_bytes > 0) {
+ // Copy host -> device. (Empty tensors don't have backing buffers.)
+ VLOG(1) << "Constant output tensor on device";
+ Tensor* output_tensor;
+ TF_CHECK_OK(
+ ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+
+ const void* src_ptr = DMAHelper::base(&const_tensor);
+ void* dst_ptr = DMAHelper::base(output_tensor);
+ gpu::DeviceMemoryBase gpu_dst_ptr(dst_ptr, total_bytes);
+ stream->ThenMemcpy(&gpu_dst_ptr, src_ptr, total_bytes);
+ } else {
+ // No copy required.
+ ctx->set_output(i, const_tensor);
+ }
+ } else {
+ const TensorShape& shape = kernel->outputs[i].shape;
+ VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
+
+ gpu::DeviceMemoryBase buffer;
+ if (output_is_tuple) {
+ buffer = output->buffer({output_num});
+ } else {
+ CHECK_EQ(0, output_num);
+ buffer = output->buffer({});
+ }
+ Tensor output_tensor;
+ // Looks up the owning Tensor by buffer address.
+ OP_REQUIRES_OK(ctx, xla_allocator.MakeTensorFromBuffer(
+ buffer, ctx->expected_output_dtype(i), shape,
+ &output_tensor));
+ ctx->set_output(i, output_tensor);
+ ++output_num;
+ }
+
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << ctx->mutable_output(i)->DebugString();
+ }
+ }
+
+ VLOG(1) << "Done";
+}
+
+XlaLocalLaunchOp::~XlaLocalLaunchOp() {
+ VLOG(1) << "XlaLocalLaunchOp destroyed";
+}
+
+REGISTER_KERNEL_BUILDER(Name("_XlaLaunch").Device(DEVICE_CPU),
+ XlaLocalLaunchOp);
+
+REGISTER_KERNEL_BUILDER(
+ Name("_XlaLaunch").Device(DEVICE_GPU).HostMemory("constants"),
+ XlaLocalLaunchOp);
+
+} // namespace tensorflow
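The XlaAllocator above services XLA's allocation requests out of ordinary TensorFlow Tensors and keeps a pointer-to-Tensor map so that MakeTensorFromBuffer can wrap an XLA output buffer back into a Tensor without copying it. A minimal standalone model of that bookkeeping follows; PointerTrackingAllocator and the byte-vector backing store are illustrative, not the real allocator.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <vector>

class PointerTrackingAllocator {
 public:
  // Back each allocation with a host-owned buffer and remember who owns it.
  void* Allocate(std::size_t size) {
    auto buffer = std::make_shared<std::vector<uint8_t>>(size);
    void* ptr = buffer->data();
    buffers_[ptr] = std::move(buffer);
    return ptr;
  }

  // Rejects addresses it never handed out, mirroring the
  // "Unknown tensor address" errors above.
  bool Deallocate(void* ptr) { return buffers_.erase(ptr) == 1; }

  // Analogue of MakeTensorFromBuffer: look up the owning buffer by address.
  std::shared_ptr<std::vector<uint8_t>> OwnerOf(void* ptr) const {
    auto it = buffers_.find(ptr);
    return it == buffers_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<void*, std::shared_ptr<std::vector<uint8_t>>> buffers_;
};

int main() {
  PointerTrackingAllocator allocator;
  void* ptr = allocator.Allocate(64);
  std::cout << (allocator.OwnerOf(ptr) != nullptr) << "\n";  // 1
  std::cout << allocator.Deallocate(ptr) << "\n";            // 1
  return 0;
}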
diff --git a/tensorflow/compiler/jit/xla_local_launch_op.h b/tensorflow/compiler/jit/xla_local_launch_op.h
new file mode 100644
index 0000000000..96ae664cbe
--- /dev/null
+++ b/tensorflow/compiler/jit/xla_local_launch_op.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
+#define TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
+
+#include "tensorflow/compiler/jit/xla_compilation_cache.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// XlaLocalLaunchOp is used to replace a region of the TensorFlow graph
+// which will be compiled and executed using XLA. The XlaLocalLaunchOp is
+// responsible for handling interactions with the TensorFlow executor.
+// Once all inputs are present, and their shapes are known, the op can
+// use a 'XlaCompilationCache' to compile and execute code which is specific
+// to the shapes of input Tensors.
+// XlaLocalLaunchOp uses xla::LocalClient::ExecuteLocally and passes
+// arguments into/out of XLA in device memory.
+class XlaLocalLaunchOp : public OpKernel {
+ public:
+ explicit XlaLocalLaunchOp(OpKernelConstruction* ctx);
+ ~XlaLocalLaunchOp() override;
+
+ void Compute(OpKernelContext* ctx) override;
+
+ private:
+ // Builds a XlaCompilationCache class suitable for the current device.
+ Status BuildCompilationCache(XlaCompilationCache** compiler);
+
+ DeviceType device_type_;
+ NameAttrList function_;
+ int num_constant_args_;
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_JIT_XLA_LOCAL_LAUNCH_OP_H_
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
new file mode 100644
index 0000000000..b4f01de4f2
--- /dev/null
+++ b/tensorflow/compiler/tests/BUILD
@@ -0,0 +1,352 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//tensorflow/compiler/tf2xla:internal",
+ ],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
+load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites")
+
+generate_backend_suites()
+
+py_library(
+ name = "xla_test",
+ testonly = 1,
+ srcs = ["xla_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cc_library(
+ name = "depthwise_conv2d_test_kernel",
+ testonly = 1,
+ srcs = ["depthwise_conv2d_test_kernel.cc"],
+ deps = ["//tensorflow/core:framework_lite"],
+)
+
+tf_xla_py_test(
+ name = "binary_ops_test",
+ size = "small",
+ srcs = ["binary_ops_test.py"],
+ shard_count = 5,
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:math_ops_gen",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "clustering_test",
+ size = "small",
+ srcs = ["clustering_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "concat_ops_test",
+ size = "small",
+ srcs = ["concat_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:array_ops_gen",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradient_checker",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "conv2d_test",
+ size = "medium",
+ srcs = ["conv2d_test.py"],
+ shard_count = 10,
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "dynamic_stitch_test",
+ size = "small",
+ srcs = ["dynamic_stitch_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "function_test",
+ size = "small",
+ srcs = ["function_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "lrn_ops_test",
+ size = "medium",
+ srcs = ["lrn_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "nary_ops_test",
+ size = "small",
+ srcs = ["nary_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "nullary_ops_test",
+ size = "small",
+ srcs = ["nullary_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "pooling_ops_test",
+ size = "medium",
+ srcs = ["pooling_ops_test.py"],
+ shard_count = 10,
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "reduce_ops_test",
+ size = "medium",
+ srcs = ["reduce_ops_test.py"],
+ shard_count = 5,
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "ternary_ops_test",
+ size = "small",
+ srcs = ["ternary_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
+ name = "unary_ops_test",
+ size = "small",
+ srcs = ["unary_ops_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:nn_ops_gen",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
+ name = "xla_device_test",
+ size = "small",
+ srcs = ["xla_device_test.py"],
+ additional_deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+cuda_py_test(
+ name = "jit_test",
+ size = "medium",
+ srcs = ["jit_test.py"],
+ additional_deps = [
+ "//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ ],
+)
+
+cc_library(
+ name = "randomized_tests_library",
+ testonly = 1,
+ srcs = ["randomized_tests.cc"],
+ deps = [
+ "//tensorflow/compiler/jit",
+ "//tensorflow/compiler/jit:common",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "randomized_tests",
+ # This test is randomized, so only run it if explicitly requested.
+ tags = [
+ "manual",
+ "noguitar",
+ "notap",
+ ],
+ deps = [":randomized_tests_library"],
+)
+
+py_library(
+ name = "lstm",
+ testonly = 1,
+ srcs = ["lstm.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
+ name = "lstm_test",
+ srcs = ["lstm_test.py"],
+ additional_deps = [
+ ":lstm",
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:variables",
+ ],
+)
+
+# An example of ahead-of-time compilation using tfcompile. The
+# lstm_layer_inference.pbtxt file was generated by running lstm_test
+# --dump_graph_dir, and the config file was written by hand.
+#
+# Run the following to build a minimal benchmark of the computation on Android:
+# $ bazel build -c opt --config=android_arm \
+# third_party/tensorflow/compiler/tests:lstm_layer_inference_benchmark
+#
+# Currently the resulting binary size is ~190KB
+tf_library(
+ name = "lstm_layer_inference",
+ testonly = 1,
+ config = "lstm_layer_inference.config.pbtxt",
+ cpp_class = "LSTMLayerInference",
+ graph = "lstm_layer_inference.pbtxt",
+ tags = ["manual"],
+ tfcompile_flags = "--xla_cpu_multi_thread_eigen=false",
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
new file mode 100644
index 0000000000..9d197b0646
--- /dev/null
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -0,0 +1,749 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for binary operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import googletest
+
+
+class BinaryOpsTest(XLATestCase):
+ """Test cases for binary operators."""
+
+ def _testBinary(self, op, a, b, expected, equality_test=None):
+ with self.test_session() as session:
+ with self.test_scope():
+ pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
+ pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
+ output = op(pa, pb)
+ result = session.run(output, {pa: a, pb: b})
+ if equality_test is None:
+ equality_test = self.assertAllClose
+ equality_test(result, expected, rtol=1e-3)
+
+ def ListsAreClose(self, result, expected, rtol):
+ """Tests closeness of two lists of floats."""
+ self.assertEqual(len(result), len(expected))
+ for i in range(len(result)):
+ self.assertAllClose(result[i], expected[i], rtol)
+
+ def testFloatOps(self):
+ for dtype in self.float_types:
+ self._testBinary(
+ gen_math_ops._real_div,
+ np.array([3, 3, -1.5, -8, 44], dtype=dtype),
+ np.array([2, -2, 7, -4, 0], dtype=dtype),
+ expected=np.array(
+ [1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype))
+
+ self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81))
+
+ self._testBinary(
+ math_ops.pow,
+ np.array([1, 2], dtype=dtype),
+ np.zeros(shape=[0, 2], dtype=dtype),
+ expected=np.zeros(shape=[0, 2], dtype=dtype))
+ self._testBinary(
+ math_ops.pow,
+ np.array([10, 4], dtype=dtype),
+ np.array([2, 3], dtype=dtype),
+ expected=np.array([100, 64], dtype=dtype))
+ self._testBinary(
+ math_ops.pow,
+ dtype(2),
+ np.array([3, 4], dtype=dtype),
+ expected=np.array([8, 16], dtype=dtype))
+ self._testBinary(
+ math_ops.pow,
+ np.array([[2], [3]], dtype=dtype),
+ dtype(4),
+ expected=np.array([[16], [81]], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops._sigmoid_grad,
+ np.array([4, 3, 2, 1], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array([-60, -36, -14, 0], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops._rsqrt_grad,
+ np.array([4, 3, 2, 1], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array([-160, -81, -28, -4], dtype=dtype))
+
+ self._testBinary(
+ gen_nn_ops._softplus_grad,
+ np.array([4, 3, 2, 1], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array(
+ [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops._tanh_grad,
+ np.array([4, 3, 2, 1], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array([-75, -48, -21, 0], dtype=dtype))
+
+ self._testBinary(
+ gen_nn_ops._relu_grad,
+ np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
+ np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
+ expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype))
+
+ self._testBinary(
+ gen_nn_ops._relu6_grad,
+ np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype),
+ np.array(
+ [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
+ expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
+
+ self._testBinary(
+ gen_nn_ops._softmax_cross_entropy_with_logits,
+ np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),
+ np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype),
+ expected=[
+ np.array([1.44019, 2.44019], dtype=dtype),
+ np.array([[-0.067941, -0.112856, -0.063117, 0.243914],
+ [-0.367941, -0.212856, 0.036883, 0.543914]],
+ dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
+ def testIntOps(self):
+ for dtype in self.int_types:
+ self._testBinary(
+ gen_math_ops._truncate_div,
+ np.array([3, 3, -1, -9, -8], dtype=dtype),
+ np.array([2, -2, 7, 2, -4], dtype=dtype),
+ expected=np.array([1, -1, 0, -4, 2], dtype=dtype))
+
+ def testNumericOps(self):
+ for dtype in self.numeric_types:
+ self._testBinary(
+ math_ops.add,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([11, 22], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ dtype(5),
+ np.array([1, 2], dtype=dtype),
+ expected=np.array([6, 7], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ np.array([[1], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[8], [9]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.sub,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([-9, -18], dtype=dtype))
+ self._testBinary(
+ math_ops.sub,
+ dtype(5),
+ np.array([1, 2], dtype=dtype),
+ expected=np.array([4, 3], dtype=dtype))
+ self._testBinary(
+ math_ops.sub,
+ np.array([[1], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[-6], [-5]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.maximum,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([10, 20], dtype=dtype))
+ self._testBinary(
+ math_ops.maximum,
+ dtype(5),
+ np.array([1, 20], dtype=dtype),
+ expected=np.array([5, 20], dtype=dtype))
+ self._testBinary(
+ math_ops.maximum,
+ np.array([[10], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[10], [7]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.minimum,
+ np.array([1, 20], dtype=dtype),
+ np.array([10, 2], dtype=dtype),
+ expected=np.array([1, 2], dtype=dtype))
+ self._testBinary(
+ math_ops.minimum,
+ dtype(5),
+ np.array([1, 20], dtype=dtype),
+ expected=np.array([1, 5], dtype=dtype))
+ self._testBinary(
+ math_ops.minimum,
+ np.array([[10], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[7], [2]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.mul,
+ np.array([1, 20], dtype=dtype),
+ np.array([10, 2], dtype=dtype),
+ expected=np.array([10, 40], dtype=dtype))
+ self._testBinary(
+ math_ops.mul,
+ dtype(5),
+ np.array([1, 20], dtype=dtype),
+ expected=np.array([5, 100], dtype=dtype))
+ self._testBinary(
+ math_ops.mul,
+ np.array([[10], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[70], [14]], dtype=dtype))
+
+ self._testBinary(
+ math_ops.squared_difference,
+ np.array([1, 2], dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([81, 324], dtype=dtype))
+ self._testBinary(
+ math_ops.squared_difference,
+ dtype(5),
+ np.array([1, 2], dtype=dtype),
+ expected=np.array([16, 9], dtype=dtype))
+ self._testBinary(
+ math_ops.squared_difference,
+ np.array([[1], [2]], dtype=dtype),
+ dtype(7),
+ expected=np.array([[36], [25]], dtype=dtype))
+
+ self._testBinary(
+ nn_ops.bias_add,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([2, -1], dtype=dtype),
+ expected=np.array([[3, 1], [5, 3]], dtype=dtype))
+ self._testBinary(
+ nn_ops.bias_add,
+ np.array([[[[1, 2], [3, 4]]]], dtype=dtype),
+ np.array([2, -1], dtype=dtype),
+ expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
+
+ def _testDivision(self, dtype):
+ """Test cases for division operators."""
+ self._testBinary(
+ math_ops.div,
+ np.array([10, 20], dtype=dtype),
+ np.array([10, 2], dtype=dtype),
+ expected=np.array([1, 10], dtype=dtype))
+ self._testBinary(
+ math_ops.div,
+ dtype(40),
+ np.array([2, 20], dtype=dtype),
+ expected=np.array([20, 2], dtype=dtype))
+ self._testBinary(
+ math_ops.div,
+ np.array([[10], [4]], dtype=dtype),
+ dtype(2),
+ expected=np.array([[5], [2]], dtype=dtype))
+
+ self._testBinary(
+ gen_math_ops._floor_div,
+ np.array([3, 3, -1, -9, -8], dtype=dtype),
+ np.array([2, -2, 7, 2, -4], dtype=dtype),
+ expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
+
+ def testIntDivision(self):
+ for dtype in self.int_types:
+ self._testDivision(dtype)
+
+ def testFloatDivision(self):
+ for dtype in self.float_types:
+ self._testDivision(dtype)
+
+ def _testRemainder(self, dtype):
+ """Test cases for remainder operators."""
+ self._testBinary(
+ gen_math_ops._floor_mod,
+ np.array([3, 3, -1, -8], dtype=dtype),
+ np.array([2, -2, 7, -4], dtype=dtype),
+ expected=np.array([1, -1, 6, 0], dtype=dtype))
+ self._testBinary(
+ gen_math_ops._truncate_mod,
+ np.array([3, 3, -1, -8], dtype=dtype),
+ np.array([2, -2, 7, -4], dtype=dtype),
+ expected=np.array([1, 1, -1, 0], dtype=dtype))
+
+ def testIntRemainder(self):
+ for dtype in self.int_types:
+ self._testRemainder(dtype)
+
+ def testFloatRemainder(self):
+ for dtype in self.float_types:
+ self._testRemainder(dtype)
+
+ def testLogicalOps(self):
+ self._testBinary(
+ math_ops.logical_and,
+ np.array([[True, False], [False, True]], dtype=np.bool),
+ np.array([[False, True], [False, True]], dtype=np.bool),
+ expected=np.array([[False, False], [False, True]], dtype=np.bool))
+
+ self._testBinary(
+ math_ops.logical_or,
+ np.array([[True, False], [False, True]], dtype=np.bool),
+ np.array([[False, True], [False, True]], dtype=np.bool),
+ expected=np.array([[True, True], [False, True]], dtype=np.bool))
+
+ def testComparisons(self):
+ self._testBinary(
+ math_ops.equal,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([False, True, False], dtype=np.bool))
+ self._testBinary(
+ math_ops.equal,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([False, True, False], dtype=np.bool))
+ self._testBinary(
+ math_ops.equal,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[False], [True], [False]], dtype=np.bool))
+
+ self._testBinary(
+ math_ops.not_equal,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([True, False, True], dtype=np.bool))
+ self._testBinary(
+ math_ops.not_equal,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([True, False, True], dtype=np.bool))
+ self._testBinary(
+ math_ops.not_equal,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[True], [False], [True]], dtype=np.bool))
+
+ for greater_op in [math_ops.greater, (lambda x, y: x > y)]:
+ self._testBinary(
+ greater_op,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([False, False, True], dtype=np.bool))
+ self._testBinary(
+ greater_op,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([True, False, False], dtype=np.bool))
+ self._testBinary(
+ greater_op,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[True], [False], [False]], dtype=np.bool))
+
+ for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]:
+ self._testBinary(
+ greater_equal_op,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([False, True, True], dtype=np.bool))
+ self._testBinary(
+ greater_equal_op,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([True, True, False], dtype=np.bool))
+ self._testBinary(
+ greater_equal_op,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[True], [True], [False]], dtype=np.bool))
+
+ for less_op in [math_ops.less, (lambda x, y: x < y)]:
+ self._testBinary(
+ less_op,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([True, False, False], dtype=np.bool))
+ self._testBinary(
+ less_op,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([False, False, True], dtype=np.bool))
+ self._testBinary(
+ less_op,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[False], [False], [True]], dtype=np.bool))
+
+ for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
+ self._testBinary(
+ less_equal_op,
+ np.array([1, 5, 20], dtype=np.float32),
+ np.array([10, 5, 2], dtype=np.float32),
+ expected=np.array([True, True, False], dtype=np.bool))
+ self._testBinary(
+ less_equal_op,
+ np.float32(5),
+ np.array([1, 5, 20], dtype=np.float32),
+ expected=np.array([False, True, True], dtype=np.bool))
+ self._testBinary(
+ less_equal_op,
+ np.array([[10], [7], [2]], dtype=np.float32),
+ np.float32(7),
+ expected=np.array([[False], [True], [True]], dtype=np.bool))
+
+ def testBroadcasting(self):
+ """Tests broadcasting behavior of an operator."""
+
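+    # Broadcasting follows NumPy semantics: scalars broadcast against any
+    # shape, and size-1 dimensions are stretched to match the other operand.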
+ for dtype in self.numeric_types:
+ self._testBinary(
+ math_ops.add,
+ np.array(3, dtype=dtype),
+ np.array([10, 20], dtype=dtype),
+ expected=np.array([13, 23], dtype=dtype))
+ self._testBinary(
+ math_ops.add,
+ np.array([10, 20], dtype=dtype),
+ np.array(4, dtype=dtype),
+ expected=np.array([14, 24], dtype=dtype))
+
+ # [1,3] x [4,1] => [4,3]
+ self._testBinary(
+ math_ops.add,
+ np.array([[10, 20, 30]], dtype=dtype),
+ np.array([[1], [2], [3], [4]], dtype=dtype),
+ expected=np.array(
+ [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
+ dtype=dtype))
+
+ # [3] * [4,1] => [4,3]
+ self._testBinary(
+ math_ops.add,
+ np.array([10, 20, 30], dtype=dtype),
+ np.array([[1], [2], [3], [4]], dtype=dtype),
+ expected=np.array(
+ [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
+ dtype=dtype))
+
+ def testFill(self):
+ for dtype in self.numeric_types:
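+      # fill(dims, value) builds a tensor of shape `dims` filled with
+      # `value`; an empty dims vector yields a scalar.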
+ self._testBinary(
+ array_ops.fill,
+ np.array([], dtype=np.int32),
+ dtype(-42),
+ expected=dtype(-42))
+ self._testBinary(
+ array_ops.fill,
+ np.array([1, 2], dtype=np.int32),
+ dtype(7),
+ expected=np.array([[7, 7]], dtype=dtype))
+ self._testBinary(
+ array_ops.fill,
+ np.array([3, 2], dtype=np.int32),
+ dtype(50),
+ expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype))
+
+ # Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below.
+ def _testMatMul(self, op):
+ for dtype in self.float_types:
+ self._testBinary(
+ op,
+ np.array([[-0.25]], dtype=dtype),
+ np.array([[8]], dtype=dtype),
+ expected=np.array([[-2]], dtype=dtype))
+ self._testBinary(
+ op,
+ np.array([[100, 10, 0.5]], dtype=dtype),
+ np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
+ expected=np.array([[123, 354]], dtype=dtype))
+ self._testBinary(
+ op,
+ np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
+ np.array([[100], [10]], dtype=dtype),
+ expected=np.array([[130], [250], [680]], dtype=dtype))
+ self._testBinary(
+ op,
+ np.array([[1000, 100], [10, 1]], dtype=dtype),
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ expected=np.array([[1300, 2400], [13, 24]], dtype=dtype))
+
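+      # Degenerate case: a [2, 0] times [0, 3] product contracts over an
+      # empty dimension and yields a [2, 3] matrix of zeros.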
+ self._testBinary(
+ op,
+ np.array([], dtype=dtype).reshape((2, 0)),
+ np.array([], dtype=dtype).reshape((0, 3)),
+ expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
+
+ def testMatMul(self):
+ self._testMatMul(math_ops.matmul)
+
+ # TODO(phawkins): failing on GPU, no registered kernel.
+ def DISABLED_testSparseMatMul(self):
+ # Binary wrappers for sparse_matmul with different hints
+ def SparseMatmulWrapperTF(a, b):
+ return tf.sparse_matmul(a, b, a_is_sparse=True)
+
+ def SparseMatmulWrapperFT(a, b):
+ return tf.sparse_matmul(a, b, b_is_sparse=True)
+
+ def SparseMatmulWrapperTT(a, b):
+ return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
+
+ self._testMatMul(tf.sparse_matmul)
+ self._testMatMul(SparseMatmulWrapperTF)
+ self._testMatMul(SparseMatmulWrapperFT)
+ self._testMatMul(SparseMatmulWrapperTT)
+
+ def testBatchMatMul(self):
+ # Same tests as for tf.matmul above.
+ self._testMatMul(math_ops.matmul)
+
+ # Tests with batches of matrices.
+ self._testBinary(
+ math_ops.matmul,
+ np.array([[[-0.25]]], dtype=np.float32),
+ np.array([[[8]]], dtype=np.float32),
+ expected=np.array([[[-2]]], dtype=np.float32))
+ self._testBinary(
+ math_ops.matmul,
+ np.array([[[-0.25]], [[4]]], dtype=np.float32),
+ np.array([[[8]], [[2]]], dtype=np.float32),
+ expected=np.array([[[-2]], [[8]]], dtype=np.float32))
+ self._testBinary(
+ math_ops.matmul,
+ np.array(
+ [[[[1000, 100], [10, 1]], [[2000, 200], [20, 2]]],
+ [[[3000, 300], [30, 3]], [[4000, 400], [40, 4]]]],
+ dtype=np.float32),
+ np.array(
+ [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]],
+ [[55, 66], [77, 88]]]],
+ dtype=np.float32),
+ expected=np.array(
+ [[[[1300, 2400], [13, 24]], [[11400, 13600], [114, 136]]],
+ [[[42900, 79200], [429, 792]], [[250800, 299200], [2508, 2992]]]],
+ dtype=np.float32))
+ self._testBinary(
+ math_ops.matmul,
+ np.array([], dtype=np.float32).reshape((2, 2, 0)),
+ np.array([], dtype=np.float32).reshape((2, 0, 3)),
+ expected=np.array(
+ [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]],
+ dtype=np.float32))
+ self._testBinary(
+ math_ops.matmul,
+ np.array([], dtype=np.float32).reshape((0, 2, 4)),
+ np.array([], dtype=np.float32).reshape((0, 4, 3)),
+ expected=np.array([], dtype=np.float32).reshape(0, 2, 3))
+
+ # Regression test for b/31472796.
+ if hasattr(np, "matmul"):
+ x = np.arange(0, 3 * 5 * 16 * 7, dtype=np.float32).reshape((3, 5, 16, 7))
+ self._testBinary(
+ lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
+ x, x,
+ expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
+
+ def testExpandDims(self):
+ for dtype in self.numeric_types:
+ self._testBinary(
+ array_ops.expand_dims,
+ dtype(7),
+ np.int32(0),
+ expected=np.array([7], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([42], dtype=dtype),
+ np.int32(0),
+ expected=np.array([[42]], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([], dtype=dtype),
+ np.int32(0),
+ expected=np.array([[]], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([[[1, 2], [3, 4]]], dtype=dtype),
+ np.int32(0),
+ expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([[[1, 2], [3, 4]]], dtype=dtype),
+ np.int32(1),
+ expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([[[1, 2], [3, 4]]], dtype=dtype),
+ np.int32(2),
+ expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
+ self._testBinary(
+ array_ops.expand_dims,
+ np.array([[[1, 2], [3, 4]]], dtype=dtype),
+ np.int32(3),
+ expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
+
+ def testPad(self):
+ for dtype in self.numeric_types:
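+      # The second operand lists (before, after) padding per dimension: one
+      # row above and two below, two columns on the left and one on the
+      # right, so the 2x3 input pads out to 5x6.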
+ self._testBinary(
+ array_ops.pad,
+ np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array(
+ [[1, 2], [2, 1]], dtype=np.int32),
+ expected=np.array(
+ [[0, 0, 0, 0, 0, 0],
+ [0, 0, 1, 2, 3, 0],
+ [0, 0, 4, 5, 6, 0],
+ [0, 0, 0, 0, 0, 0],
+ [0, 0, 0, 0, 0, 0]],
+ dtype=dtype))
+
+ def testReshape(self):
+ for dtype in self.numeric_types:
+ self._testBinary(
+ array_ops.reshape,
+ np.array([], dtype=dtype),
+ np.array([0, 4], dtype=np.int32),
+ expected=np.zeros(shape=[0, 4], dtype=dtype))
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([2, 3], dtype=np.int32),
+ expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([3, 2], dtype=np.int32),
+ expected=np.array([[0, 1], [2, 3], [4, 5]], dtype=dtype))
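+      # A -1 entry in the shape argument is inferred from the remaining
+      # element count.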
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([-1, 6], dtype=np.int32),
+ expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype))
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([6, -1], dtype=np.int32),
+ expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype))
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([2, -1], dtype=np.int32),
+ expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
+ self._testBinary(
+ array_ops.reshape,
+ np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
+ np.array([-1, 3], dtype=np.int32),
+ expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
+
+ def testSplit(self):
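+    # _testBinary passes its two operands positionally, so the lambdas below
+    # reorder them: x is the split axis and y is the value being split.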
+ for dtype in self.numeric_types:
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
+ np.int32(0),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1], [2]]], dtype=dtype),
+ np.array([[[3], [4]]], dtype=dtype),
+ np.array([[[5], [6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
+ self._testBinary(
+ lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
+ np.int32(1),
+ np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
+ dtype=dtype),
+ expected=[
+ np.array([[[1]], [[3]], [[5]]], dtype=dtype),
+ np.array([[[2]], [[4]], [[6]]], dtype=dtype),
+ ],
+ equality_test=self.ListsAreClose)
+
+ def testTile(self):
+ for dtype in self.numeric_types:
+ self._testBinary(
+ array_ops.tile,
+ np.array([[6]], dtype=dtype),
+ np.array([1, 2], dtype=np.int32),
+ expected=np.array([[6, 6]], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[1], [2]], dtype=dtype),
+ np.array([1, 2], dtype=np.int32),
+ expected=np.array([[1, 1], [2, 2]], dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([3, 2], dtype=np.int32),
+ expected=np.array(
+ [[1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4],
+ [1, 2, 1, 2],
+ [3, 4, 3, 4]],
+ dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([1, 1], dtype=np.int32),
+ expected=np.array(
+ [[1, 2],
+ [3, 4]],
+ dtype=dtype))
+ self._testBinary(
+ array_ops.tile,
+ np.array([[1, 2]], dtype=dtype),
+ np.array([3, 1], dtype=np.int32),
+ expected=np.array(
+ [[1, 2],
+ [1, 2],
+ [1, 2]],
+ dtype=dtype))
+
+ def testTranspose(self):
+ for dtype in self.numeric_types:
+ self._testBinary(
+ array_ops.transpose,
+ np.zeros(shape=[1, 0, 4], dtype=dtype),
+ np.array([1, 2, 0], dtype=np.int32),
+ expected=np.zeros(shape=[0, 4, 1], dtype=dtype))
+ self._testBinary(
+ array_ops.transpose,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([0, 1], dtype=np.int32),
+ expected=np.array([[1, 2], [3, 4]], dtype=dtype))
+ self._testBinary(
+ array_ops.transpose,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([1, 0], dtype=np.int32),
+ expected=np.array([[1, 3], [2, 4]], dtype=dtype))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
new file mode 100644
index 0000000000..7fb8e0a26d
--- /dev/null
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -0,0 +1,78 @@
+"""Build rules for Tensorflow/XLA testing."""
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
+
+def all_backends():
+ if cuda_is_configured():
+ return ["cpu", "gpu"]
+ else:
+ return ["cpu"]
+
+def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None,
+ backends=None, **kwargs):
+ """Generates py_test targets, one per XLA backend.
+
+ This rule generates py_test() targets named name_backend, for each backend
+  in all_backends(). The rule also generates a test suite named `name` that
+  runs the test on all backends.
+
+ For example, the following rule generates test cases foo_test_cpu,
+  foo_test_gpu, and a test suite named foo_test that tests both.
+ tf_xla_py_test(
+ name="foo_test",
+      srcs=["foo_test.py"],
+ deps=[...],
+ )
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ tags: Tags to apply to the generated targets.
+ data: Data dependencies of the target.
+ main: Same as py_test's main attribute.
+ backends: A list of backends to test. Supported values include "cpu" and
+ "gpu". If not specified, defaults to all backends.
+ **kwargs: keyword arguments passed onto the generated py_test() rules.
+ """
+ if backends == None:
+ backends = all_backends()
+
+ test_names = []
+ for backend in backends:
+ test_name = "{}_{}".format(name, backend)
+ backend_tags = ["tf_xla_{}".format(backend)]
+ backend_args = []
+ backend_deps = []
+ backend_data = []
+ if backend == "cpu":
+ backend_args += ["--test_device=XLA_CPU",
+ "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
+ elif backend == "gpu":
+ backend_args += ["--test_device=XLA_GPU",
+ "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"]
+ backend_tags += ["requires-gpu-sm35"]
+ else:
+ fail("Unknown backend {}".format(backend))
+
+ native.py_test(
+ name=test_name,
+ srcs=srcs,
+ srcs_version="PY2AND3",
+ args=backend_args,
+ main="{}.py".format(name) if main == None else main,
+ data=data + backend_data,
+ deps=deps + backend_deps,
+ tags=tags + backend_tags,
+ **kwargs
+ )
+ test_names.append(test_name)
+ native.test_suite(name=name, tests=test_names)
+
+def generate_backend_suites(backends=[]):
+ """Generates per-backend test_suites that run all tests for a backend."""
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.test_suite(name="%s_tests" % backend, tags=["tf_xla_%s" % backend])
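+
+# Example BUILD usage (illustrative sketch; the test name, sources, and deps
+# below are hypothetical):
+#
+#   load("//tensorflow/compiler/tests:build_defs.bzl",
+#        "generate_backend_suites", "tf_xla_py_test")
+#
+#   generate_backend_suites()
+#
+#   tf_xla_py_test(
+#       name = "foo_test",
+#       srcs = ["foo_test.py"],
+#       deps = ["//tensorflow/compiler/tests:xla_test"],
+#   )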
+
diff --git a/tensorflow/compiler/tests/clustering_test.py b/tensorflow/compiler/tests/clustering_test.py
new file mode 100644
index 0000000000..6993648531
--- /dev/null
+++ b/tensorflow/compiler/tests/clustering_test.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the behavior of the auto-compilation pass."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange  # pylint: disable=redefined-builtin
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
+
+
+class ClusteringTest(XLATestCase):
+
+ def testAdd(self):
+ val1 = np.array([4, 3, 2, 1], dtype=np.float32)
+ val2 = np.array([5, 6, 7, 8], dtype=np.float32)
+ expected = val1 + val2
+ with self.test_session():
+ with self.test_scope():
+ input1 = constant_op.constant(val1, name="const1")
+ input2 = constant_op.constant(val2, name="const2")
+ output = math_ops.add(input1, input2)
+ result = output.eval()
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testAddFromCpuMultiple(self):
+ val1 = np.array([4, 3, 2, 1]).astype(np.float32)
+ val2 = np.array([5, 6, 7, 8]).astype(np.float32)
+ expected = val1 + val2
+ with self.test_session():
+ with ops.device(CPU_DEVICE):
+ input1 = constant_op.constant(val1, name="const1")
+ input2 = constant_op.constant(val2, name="const2")
+ with self.test_scope():
+ output = math_ops.add(input1, input2)
+ for _ in xrange(10):
+ result = output.eval()
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testDeadlock(self):
+ # Builds a graph of the form:
+ # x -> y
+ # | \
+ # z -> w
+ # where x and z are placed on the CPU and y and w are placed on the XLA
+ # device. If y and w are clustered for compilation, then the graph will
+ # deadlock since the clustered graph will contain a self-loop.
+ with self.test_session() as sess:
+ with ops.device(CPU_DEVICE):
+ x = array_ops.placeholder(dtypes.float32, [2])
+ with self.test_scope():
+ y = x * 2
+ with ops.device(CPU_DEVICE):
+ z = y * y
+ with self.test_scope():
+ w = y + z
+ result = sess.run(w, {x: [1.5, 0.5]})
+ self.assertAllClose(result, [12., 2.], rtol=1e-3)
+
+ def testHostMemory(self):
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.int32)
+ with self.test_scope():
+ y = x + 1
+ with ops.device(CPU_DEVICE):
+ # Place a computation on the CPU, so y and w cannot be merged into the
+ # same JIT compilation.
+ z = y * 2
+ with self.test_scope():
+ # Argument 'y' is a non-constant output of a previous cluster. Make sure
+ # it is properly copied to host memory so it can be used as a
+ # compile-time constant input for this cluster.
+ w = array_ops.reshape(z, y)
+ result = sess.run(w, {x: [1, 0]})
+ expected = np.array([[4], [2]], dtype=np.int32)
+ self.assertAllClose(expected, result, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/concat_ops_test.py b/tensorflow/compiler/tests/concat_ops_test.py
new file mode 100644
index 0000000000..6f74ae702c
--- /dev/null
+++ b/tensorflow/compiler/tests/concat_ops_test.py
@@ -0,0 +1,374 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for XLA Concat Op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class ConcatTest(XLATestCase):
+
+ def testHStack(self):
+ with self.test_session():
+ p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
+ p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
+ with self.test_scope():
+ c = array_ops.concat_v2([p1, p2], 0)
+ params = {
+ p1: np.random.rand(4, 4).astype("f"),
+ p2: np.random.rand(4, 4).astype("f")
+ }
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ self.assertAllEqual(result[:4, :], params[p1])
+ self.assertAllEqual(result[4:, :], params[p2])
+
+ def testVStack(self):
+ with self.test_session():
+ p1 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
+ p2 = array_ops.placeholder(dtypes.float32, shape=[4, 4])
+ with self.test_scope():
+ c = array_ops.concat_v2([p1, p2], 1)
+ params = {
+ p1: np.random.rand(4, 4).astype("f"),
+ p2: np.random.rand(4, 4).astype("f")
+ }
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ self.assertAllEqual(result[:, :4], params[p1])
+ self.assertAllEqual(result[:, 4:], params[p2])
+
+ def testInt32(self):
+ with self.test_session():
+ p1 = np.random.rand(2, 3).astype("i")
+ p2 = np.random.rand(2, 3).astype("i")
+ x1 = constant_op.constant(p1)
+ x2 = constant_op.constant(p2)
+ with self.test_scope():
+ c = array_ops.concat_v2([x1, x2], 0)
+ result = c.eval()
+ self.assertAllEqual(result[:2, :], p1)
+ self.assertAllEqual(result[2:, :], p2)
+
+ def _testRandom(self, dtype):
+ # Random dims of rank 5
+ shape = np.random.randint(1, 5, size=5)
+ # Random number of tensors, but always > 1.
+ num_tensors = np.random.randint(2, 10)
+ # Random dim to concat on
+ concat_dim = np.random.randint(5)
+ params = {}
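+    # NumPy cannot feed bfloat16 directly, so bfloat16 runs use float32
+    # placeholders and cast to/from bfloat16 around the concat.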
+ if dtype == dtypes.bfloat16:
+ dtype_feed = dtypes.float32
+ else:
+ dtype_feed = dtype
+ with self.test_session():
+ p = []
+ for i in np.arange(num_tensors):
+ input_shape = shape
+ input_shape[concat_dim] = np.random.randint(1, 5)
+ placeholder = array_ops.placeholder(dtype_feed, shape=input_shape)
+ p.append(placeholder)
+
+ t = dtype_feed.as_numpy_dtype
+ params[placeholder] = np.random.rand(*input_shape).astype(t)
+
+ if dtype != dtype_feed:
+ concat_inputs = [math_ops.cast(p_i, dtype) for p_i in p]
+ else:
+ concat_inputs = p
+ with self.test_scope():
+ c = array_ops.concat_v2(concat_inputs, concat_dim)
+ if dtype != dtype_feed:
+ c = math_ops.cast(c, dtype_feed)
+ result = c.eval(feed_dict=params)
+
+ self.assertEqual(result.shape, c.get_shape())
+ cur_offset = 0
+
+ for i in np.arange(num_tensors):
+ # The index into the result is the ':' along all dimensions
+ # except the concat_dim. slice(0, size) is used for ':', and
+ # a list of slices is used to index into result.
+ ind = [slice(0, params[p[i]].shape[j]) for j in np.arange(5)]
+ ind[concat_dim] = slice(cur_offset,
+ cur_offset + params[p[i]].shape[concat_dim])
+ cur_offset += params[p[i]].shape[concat_dim]
+ if dtype == dtype_feed:
+ self.assertAllEqual(result[ind], params[p[i]])
+ else:
+ self.assertAllClose(result[ind], params[p[i]], 0.01)
+
+ def testRandom(self):
+ self._testRandom(dtypes.float32)
+ self._testRandom(dtypes.int32)
+
+ def _testGradientsSimple(self):
+ with self.test_session():
+ inp = []
+ inp_tensors = []
+ with self.test_scope():
+ for x in [1, 2, 6]:
+ shape = [10, x, 2]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ constant_op.constant(
+ [float(y) for y in t.flatten()],
+ shape=shape,
+ dtype=dtypes.float32))
+ c = array_ops.concat_v2(inp_tensors, 1)
+ output_shape = [10, 9, 2]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = constant_op.constant(
+ [float(x) for x in grad_inp.flatten()], shape=output_shape)
+ grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = array_ops.concat_v2(grad, 1)
+ result = concated_grad.eval()
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsSimpleAll(self):
+ self._testGradientsSimple()
+
+ def _testGradientsFirstDim(self):
+ with self.test_session():
+ inp = []
+ inp_tensors = []
+ with self.test_scope():
+ for x in [1, 2, 6]:
+ shape = [x, 10, 2]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ constant_op.constant(
+ [float(y) for y in t.flatten()],
+ shape=shape,
+ dtype=dtypes.float32))
+ c = array_ops.concat_v2(inp_tensors, 0)
+ output_shape = [9, 10, 2]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = constant_op.constant(
+ [float(x) for x in grad_inp.flatten()], shape=output_shape)
+ grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = array_ops.concat_v2(grad, 0)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsFirstDimAll(self):
+ self._testGradientsFirstDim()
+
+ def _testGradientsLastDim(self):
+ with self.test_session():
+ inp = []
+ inp_tensors = []
+ with self.test_scope():
+ for x in [1, 2, 6]:
+ shape = [10, 2, x]
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ constant_op.constant(
+ [float(y) for y in t.flatten()],
+ shape=shape,
+ dtype=dtypes.float32))
+ c = array_ops.concat_v2(inp_tensors, 2)
+ output_shape = [10, 2, 9]
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = constant_op.constant(
+ [float(x) for x in grad_inp.flatten()], shape=output_shape)
+ grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = array_ops.concat_v2(grad, 2)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsLastDimAll(self):
+ self._testGradientsLastDim()
+
+ def _RunAndVerifyGradientsRandom(self):
+ # Random dims of rank 5
+ input_shape = np.random.randint(1, 5, size=5)
+ # Random number of tensors
+ num_tensors = np.random.randint(1, 10)
+ # Random dim to concat on
+ concat_dim = np.random.randint(5)
+ concat_dim_sizes = np.random.randint(1, 5, size=num_tensors)
+ with self.test_session():
+ inp = []
+ inp_tensors = []
+ with self.test_scope():
+ for x in concat_dim_sizes:
+ shape = input_shape
+ shape[concat_dim] = x
+ t = np.random.rand(*shape).astype("f")
+ inp.append(t)
+ inp_tensors.append(
+ constant_op.constant(
+ [float(y) for y in t.flatten()],
+ shape=shape,
+ dtype=dtypes.float32))
+ c = array_ops.concat_v2(inp_tensors, concat_dim)
+ output_shape = input_shape
+ output_shape[concat_dim] = concat_dim_sizes.sum()
+ grad_inp = np.random.rand(*output_shape).astype("f")
+ grad_tensor = constant_op.constant(
+ [float(x) for x in grad_inp.flatten()], shape=output_shape)
+ grad = gradients_impl.gradients([c], inp_tensors, [grad_tensor])
+ concated_grad = array_ops.concat_v2(grad, concat_dim)
+ result = concated_grad.eval()
+
+ self.assertAllEqual(result, grad_inp)
+
+ def testGradientsRandom(self):
+ for _ in range(5):
+ self._RunAndVerifyGradientsRandom()
+
+ # Re-enable once zero-element Retvals are handled correctly.
+ def DISABLED_testZeroSize(self):
+ # Verify that concat doesn't crash and burn for zero size inputs
+ np.random.seed(7)
+ with self.test_session() as sess:
+ with self.test_scope():
+ for shape0 in (), (2,):
+ axis = len(shape0)
+ for shape1 in (), (3,):
+ for n0 in 0, 1, 2:
+ for n1 in 0, 1, 2:
+ x0 = np.random.randn(*(shape0 + (n0,) + shape1))
+ x1 = np.random.randn(*(shape0 + (n1,) + shape1))
+ correct = np.concatenate([x0, x1], axis=axis)
+ # TODO(irving): Make tf.concat handle map, then drop list().
+ xs = list(map(constant_op.constant, [x0, x1]))
+ c = array_ops.concat_v2(xs, axis)
+ self.assertAllEqual(c.eval(), correct)
+ # Check gradients
+ dc = np.random.randn(*c.get_shape().as_list())
+ dxs = sess.run(gradients_impl.gradients(c, xs, dc))
+ self.assertAllEqual(dc, np.concatenate(dxs, axis=axis))
+
+ def testTensorConcatDim0Grad(self):
+ x_shapes = [[20, 7, 3], [10, 7, 3], [14, 7, 3]]
+ output_shape = [44, 7, 3]
+ x_vals = [
+ np.random.random_sample(x_shape).astype(np.float32)
+ for x_shape in x_shapes
+ ]
+ with self.test_session():
+ with self.test_scope():
+ xs = [constant_op.constant(x_val) for x_val in x_vals]
+ output = array_ops.concat_v2(xs, 0)
+ err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
+ output_shape)
+ self.assertLess(err, 1e-4)
+
+ def testTensorConcatDim1Grad(self):
+ x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
+ output_shape = [20, 11, 3]
+ x_vals = [
+ np.random.random_sample(x_shape).astype(np.float32)
+ for x_shape in x_shapes
+ ]
+ with self.test_session():
+ with self.test_scope():
+ xs = [constant_op.constant(x_val) for x_val in x_vals]
+ output = array_ops.concat_v2(xs, 1)
+ err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
+ output_shape)
+ self.assertLess(err, 1e-4)
+
+ def testConcatTuple(self):
+ c1 = np.random.rand(4, 4).astype(np.float32)
+ c2 = np.random.rand(4, 4).astype(np.float32)
+ with self.test_session():
+ with self.test_scope():
+ concat_list_t = array_ops.concat_v2([c1, c2], 0)
+ concat_tuple_t = array_ops.concat_v2((c1, c2), 0)
+ self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
+
+ def testConcatNoScalars(self):
+ with self.test_session():
+ with self.test_scope():
+ scalar = constant_op.constant(7)
+ dim = array_ops.placeholder(dtypes.int32)
+ with self.assertRaisesRegexp(
+ ValueError, r"Can't concatenate scalars \(use tf\.pack instead\)"):
+ array_ops.concat_v2([scalar, scalar, scalar], dim)
+
+
+class ConcatOffsetTest(XLATestCase):
+
+ def testBasic(self):
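+    # ConcatOffset returns the starting offset of each input along the
+    # concatenation axis; here axis 1 gives offsets 0, 3, and 3 + 7 = 10.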
+ with self.test_session() as sess:
+ with self.test_scope():
+ cdim = constant_op.constant(1, dtypes.int32)
+ s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+ s1 = constant_op.constant([2, 7, 5], dtypes.int32)
+ s2 = constant_op.constant([2, 20, 5], dtypes.int32)
+ off = gen_array_ops._concat_offset(cdim, [s0, s1, s2])
+ ans = sess.run(off)
+ self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
+
+
+class PackTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ s0 = constant_op.constant([2, 3, 5], dtypes.int32)
+ s1 = constant_op.constant([2, 7, 5], dtypes.int32)
+ s2 = constant_op.constant([2, 20, 5], dtypes.int32)
+ packed = array_ops.stack([s0, s1, s2])
+ ans = sess.run(packed)
+ self.assertAllEqual(ans, [[2, 3, 5], [2, 7, 5], [2, 20, 5]])
+
+ def testScalars(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ s0 = constant_op.constant(2, dtypes.int32)
+ s1 = constant_op.constant(3, dtypes.int32)
+ s2 = constant_op.constant(5, dtypes.int32)
+ packed = array_ops.stack([s0, s1, s2])
+ ans = sess.run(packed)
+ self.assertAllEqual(ans, [2, 3, 5])
+
+ def testEmpty(self):
+ with self.test_session() as sess:
+ with self.test_scope():
+ s0 = constant_op.constant([[]], dtypes.int32)
+ s1 = constant_op.constant([[]], dtypes.int32)
+ s2 = constant_op.constant([[]], dtypes.int32)
+ packed = array_ops.stack([s0, s1, s2])
+ ans = sess.run(packed)
+ self.assertAllEqual(ans, [[[]], [[]], [[]]])
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py
new file mode 100644
index 0000000000..01cfbd9f7c
--- /dev/null
+++ b/tensorflow/compiler/tests/conv2d_test.py
@@ -0,0 +1,526 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Conv2D via the XLA JIT.
+
+The canned results in these tests are created by running each test using the
+TensorFlow CPU device and saving the output.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import googletest
+
+
+class Conv2DTest(XLATestCase):
+
+ def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected):
+ """Tests that tf.nn.conv2d produces the expected value.
+
+ Args:
+ input_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ stride: Stride.
+ padding: Padding type.
+ expected: Expected output.
+ """
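+    # Both operands are 1-based ramps (1, 2, 3, ...) reshaped to the
+    # requested input and filter sizes.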
+ total_size_1 = np.prod(input_sizes)
+ total_size_2 = np.prod(filter_sizes)
+ x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
+ x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes)
+ strides = [1, stride, stride, 1]
+
+ with self.test_session() as sess:
+ with self.test_scope():
+ t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
+ t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
+ out = nn_ops.conv2d(
+ t1, t2, strides=strides, padding=padding, data_format="NHWC")
+ value = sess.run(out, {t1: x1, t2: x2})
+ self.assertArrayNear(expected, np.ravel(value), 1e-3)
+
+ def testConv2D1x1Filter(self):
+ expected_output = [
+ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0,
+ 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[1, 1, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2Filter(self):
+ expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2Filter(self):
+ expected_output = [
+ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0,
+ 936.0, 1029.0
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[1, 2, 3, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2(self):
+ expected_output = [2271.0, 2367.0, 2463.0]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ stride=2,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2Same(self):
+ expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+
+
+class Conv2DBackpropInputTest(XLATestCase):
+
+ def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
+ padding, expected):
+ """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output.
+
+ Args:
+ input_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ out_backprop_sizes: Output gradients tensor dimensions.
+ stride: Stride.
+ padding: Padding type.
+ expected: Expected output.
+ """
+ total_size_1 = np.prod(filter_sizes)
+ total_size_2 = np.prod(out_backprop_sizes)
+ x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes)
+ x2 = np.arange(
+ 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
+ strides = [1, stride, stride, 1]
+
+ with self.test_session() as sess:
+ with self.test_scope():
+ t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes)
+ t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
+ out = gen_nn_ops.conv2d_backprop_input(
+ input_sizes=input_sizes,
+ filter=t1,
+ out_backprop=t2,
+ strides=strides,
+ padding=padding,
+ data_format="NHWC")
+ value = sess.run(out, {t1: x1, t2: x2})
+ self.assertArrayNear(expected, np.ravel(value), 1e-3)
+
+ def testConv2D1x1Filter(self):
+ expected_output = [
+ 5, 11, 17, 11, 25, 39, 17, 39, 61, 23, 53, 83, 29, 67, 105, 35, 81, 127,
+ 41, 95, 149, 47, 109, 171, 53, 123, 193, 59, 137, 215, 65, 151, 237, 71,
+ 165, 259, 77, 179, 281, 83, 193, 303, 89, 207, 325, 95, 221, 347.
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 4, 4, 3],
+ filter_sizes=[1, 1, 3, 2],
+ out_backprop_sizes=[1, 4, 4, 2],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width5(self):
+ expected_output = [1, 2, 0, 2, 4]
+ self._VerifyValues(
+ input_sizes=[1, 1, 5, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width6(self):
+ expected_output = [1, 2, 0, 2, 4, 0]
+ self._VerifyValues(
+ input_sizes=[1, 1, 6, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width7(self):
+ expected_output = [1, 2, 0, 2, 4, 0, 0]
+ self._VerifyValues(
+ input_sizes=[1, 1, 7, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterC1Same(self):
+ expected_output = [1, 4, 7, 7, 23, 33]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ out_backprop_sizes=[1, 2, 3, 1],
+ stride=1,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D2x2Filter(self):
+ expected_output = [
+ 14, 32, 50, 100, 163, 226, 167, 212, 257, 122, 140, 158, 478, 541, 604,
+ 437, 482, 527
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ out_backprop_sizes=[1, 1, 2, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterSame(self):
+ expected_output = [
+ 14, 32, 50, 100, 163, 226, 217, 334, 451, 190, 307, 424, 929, 1217,
+ 1505, 1487, 1883, 2279
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ out_backprop_sizes=[1, 2, 3, 3],
+ stride=1,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D1x2Filter(self):
+ expected_output = [1, 4, 4, 3, 10, 8, 5, 16, 12]
+ self._VerifyValues(
+ input_sizes=[1, 3, 3, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 3, 2, 1],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterSame(self):
+ expected_output = [1, 4, 7, 4, 13, 16, 7, 22, 25]
+ self._VerifyValues(
+ input_sizes=[1, 3, 3, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 3, 3, 1],
+ stride=1,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2(self):
+ expected_output = [1, 2, 5, 4, 6, 0, 0, 0, 0, 0, 3, 6, 13, 8, 12]
+ self._VerifyValues(
+ input_sizes=[1, 3, 5, 1],
+ filter_sizes=[1, 3, 1, 1],
+ out_backprop_sizes=[1, 2, 2, 1],
+ stride=2,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2Same(self):
+ expected_output = [1, 2, 2, 3, 4, 6]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+
+
+class Conv2DBackpropFilterTest(XLATestCase):
+
+ def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride,
+ padding, expected):
+ """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output.
+
+ Args:
+ input_sizes: Input tensor dimensions in
+ [batch, input_rows, input_cols, input_depth].
+ filter_sizes: Filter tensor dimensions in
+ [kernel_rows, kernel_cols, input_depth, output_depth].
+ out_backprop_sizes: Output gradients tensor dimensions.
+ stride: Stride.
+ padding: Padding type.
+ expected: Expected output.
+ """
+
+ total_size_1 = np.prod(input_sizes)
+ total_size_2 = np.prod(out_backprop_sizes)
+ x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes)
+ x2 = np.arange(
+ 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes)
+ strides = [1, stride, stride, 1]
+
+ with self.test_session() as sess:
+ with self.test_scope():
+ t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes)
+ t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes)
+ tensor = gen_nn_ops.conv2d_backprop_filter(
+ input=t1,
+ filter_sizes=filter_sizes,
+ out_backprop=t2,
+ strides=strides,
+ padding=padding,
+ data_format="NHWC")
+
+ value = sess.run(tensor, {t1: x1, t2: x2})
+ self.assertArrayNear(expected, np.ravel(value), 1e-5)
+
+ def testConv2D1x1Filter(self):
+ expected_output = [8056, 8432, 8312, 8704, 8568, 8976]
+ self._VerifyValues(
+ input_sizes=[1, 4, 4, 3],
+ filter_sizes=[1, 1, 3, 2],
+ out_backprop_sizes=[1, 4, 4, 2],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2Filter(self):
+ expected_output = [120, 141]
+ self._VerifyValues(
+ input_sizes=[1, 3, 3, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 3, 2, 1],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterDepth1(self):
+ expected_output = [5, 8, 14, 17]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2Filter(self):
+ expected_output = [
+ 17, 22, 27, 22, 29, 36, 27, 36, 45, 32, 43, 54, 37, 50, 63, 42, 57, 72,
+ 62, 85, 108, 67, 92, 117, 72, 99, 126, 77, 106, 135, 82, 113, 144, 87,
+ 120, 153
+ ]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 3],
+ filter_sizes=[2, 2, 3, 3],
+ out_backprop_sizes=[1, 1, 2, 3],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width5(self):
+ expected_output = [9, 12]
+ self._VerifyValues(
+ input_sizes=[1, 1, 5, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width6(self):
+ expected_output = [9, 12]
+ self._VerifyValues(
+ input_sizes=[1, 1, 6, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x2FilterStride3Width7(self):
+ expected_output = [9, 12]
+ self._VerifyValues(
+ input_sizes=[1, 1, 7, 1],
+ filter_sizes=[1, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=3,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x3Filter(self):
+ expected_output = [5, 8, 11]
+ self._VerifyValues(
+ input_sizes=[1, 1, 4, 1],
+ filter_sizes=[1, 3, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=1,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D1x3FilterSame(self):
+ expected_output = [20, 30, 20]
+ self._VerifyValues(
+ input_sizes=[1, 1, 4, 1],
+ filter_sizes=[1, 3, 1, 1],
+ out_backprop_sizes=[1, 1, 4, 1],
+ stride=1,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D1x3FilterSameOutbackprop2(self):
+ expected_output = [7, 10, 3]
+ self._VerifyValues(
+ input_sizes=[1, 1, 4, 1],
+ filter_sizes=[1, 3, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D2x2FilterC1Same(self):
+ expected_output = [91, 58, 32, 17]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ out_backprop_sizes=[1, 2, 3, 1],
+ stride=1,
+ padding="SAME",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2(self):
+ expected_output = [92, 102, 112]
+ self._VerifyValues(
+ input_sizes=[1, 3, 5, 1],
+ filter_sizes=[1, 3, 1, 1],
+ out_backprop_sizes=[1, 2, 2, 1],
+ stride=2,
+ padding="VALID",
+ expected=expected_output)
+
+ def testConv2D2x2FilterStride2Same(self):
+ expected_output = [7, 2, 16, 5]
+ self._VerifyValues(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ out_backprop_sizes=[1, 1, 2, 1],
+ stride=2,
+ padding="SAME",
+ expected=expected_output)
+
+
+class DepthwiseConv2DTest(XLATestCase):
+
+ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
+
+ def ConfigsToTest(self):
+ input_sizes = [[4, 35, 35, 2], [4, 147, 147, 2], [3, 299, 299, 3],
+ [5, 183, 183, 1]]
+ filter_sizes = [[5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
+ strides = [1, 3, 2, 2]
+ # pylint: disable=invalid-name
+ VALID = "VALID"
+ SAME = "SAME"
+ # pylint: enable=invalid-name
+ paddings = [SAME, VALID, SAME, SAME, SAME]
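+    # zip() stops at the shortest list, so only the first four padding values
+    # are used.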
+ for i, f, s, p in zip(input_sizes, filter_sizes, strides, paddings):
+ yield i, f, s, p
+
+ def _VerifyValues(self, input_size, filter_size, stride, padding):
+ imag = np.random.rand(*input_size).astype(np.float32)
+ filt = np.random.rand(*filter_size).astype(np.float32)
+ strides = [1, stride, stride, 1]
+
+ with self.test_session():
+ with self.test_scope():
+ imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
+ filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
+ feed_dict = {imag_ph: imag, filt_ph: filt}
+ xla_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
+ padding).eval(feed_dict=feed_dict)
+
+ with self.test_session():
+ with ops.device(self.CPU_DEVICE):
+ imag_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
+ filt_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
+ feed_dict = {imag_ph: imag, filt_ph: filt}
+ cpu_out = nn_impl.depthwise_conv2d(imag_ph, filt_ph, strides,
+ padding).eval(feed_dict=feed_dict)
+
+ self.assertAllClose(xla_out, cpu_out)
+
+ # This is disabled because we need a mechanism to set command-line flags,
+ # i.e. an implementation of SetCommandLineOption() below.
+ #
+ # def _VerifyDummy(self, input_size, filter_size, stride, padding):
+ # imag = np.random.rand(*input_size).astype(np.float32)
+ # filt = np.random.rand(*filter_size).astype(np.float32)
+ # strides = [1, stride, stride, 1]
+ #
+ # with self.test_session():
+ # with self.test_scope():
+ # imag_ph = tf.placeholder(tf.float32, shape=input_size)
+ # filt_ph = tf.placeholder(tf.float32, shape=filter_size)
+ # feed_dict = {imag_ph: imag, filt_ph: filt}
+ # SetCommandLineOption(
+ # "tf_tla_depthwise_conv2d_custom_func",
+ # "DummyDepthwiseConv2dKernel")
+ # xla_out = tf.nn.depthwise_conv2d(
+ # imag_ph, filt_ph, strides, padding).eval(feed_dict=feed_dict)
+ # SetCommandLineOption(
+ # "tf_tla_depthwise_conv2d_custom_func", "")
+ #
+ # expected = np.array(range(np.ravel(xla_out).shape[0]), dtype=np.float32)
+ # self.assertAllClose(np.ravel(xla_out), expected)
+
+ def testBasic(self):
+ for i, f, s, p in self.ConfigsToTest():
+ self._VerifyValues(i, f, s, p)
+
+ # Test disabled until _VerifyDummy(), above can be implemented.
+ # def testCustomFunc(self):
+ # if self.has_custom_call:
+ # for i, f, s, p in self.ConfigsToTest():
+ # self._VerifyDummy(i, f, s, p)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc b/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc
new file mode 100644
index 0000000000..97b71c0228
--- /dev/null
+++ b/tensorflow/compiler/tests/depthwise_conv2d_test_kernel.cc
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int64;
+
+// A dummy implementation that fills the output with 0, 1, 2,...
+// to test the custom call implementation of DepthwiseConv2dNative op.
+// TODO(keveman): Test this after adding a real implementation for the kernel.
+extern "C" void DummyDepthwiseConv2dKernel(float* output, void** inputs) {
+ const int64* output_size = reinterpret_cast<const int64*>(inputs[4]);
+ const int64 total_size =
+ output_size[0] * output_size[1] * output_size[2] * output_size[3];
+ for (int64 i = 0; i < total_size; ++i) {
+ *(output + i) = i;
+ }
+}
diff --git a/tensorflow/compiler/tests/dynamic_stitch_test.py b/tensorflow/compiler/tests/dynamic_stitch_test.py
new file mode 100644
index 0000000000..c109c27abe
--- /dev/null
+++ b/tensorflow/compiler/tests/dynamic_stitch_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tf.dynamic_stitch."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.platform import googletest
+
+
+class DynamicStitchTest(XLATestCase):
+
+ def _AssertDynamicStitchResultIs(self, indices, data, expected):
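+    # dynamic_stitch writes data[m][i] to output position indices[m][i], so
+    # the expected arrays are the inputs interleaved by index.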
+ with self.test_session() as session:
+ index_placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
+ ]
+ data_placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
+ ]
+ with self.test_scope():
+ output = data_flow_ops.dynamic_stitch(index_placeholders,
+ data_placeholders)
+
+ feed_dict = {}
+ for placeholder, value in zip(index_placeholders, indices):
+ feed_dict[placeholder] = value
+ for placeholder, value in zip(data_placeholders, data):
+ feed_dict[placeholder] = value
+ result = session.run(output, feed_dict=feed_dict)
+ self.assertAllClose(expected, result, rtol=1e-3)
+
+ def testSimpleEmpty(self):
+ idx1 = np.array([0, 2], dtype=np.int32)
+ idx2 = np.array([[1], [3]], dtype=np.int32)
+ val1 = np.array([[], []], dtype=np.int32)
+ val2 = np.array([[[]], [[]]], dtype=np.int32)
+ self._AssertDynamicStitchResultIs(
+ [idx1, idx2], [val1, val2],
+ expected=np.array([[], [], [], []], np.int32))
+
+ def testSimple1D(self):
+ val1 = np.array([0, 4, 7], dtype=np.int32)
+ val2 = np.array([1, 6, 2, 3, 5], dtype=np.int32)
+ val3 = np.array([0, 40, 70], dtype=np.float32)
+ val4 = np.array([10, 60, 20, 30, 50], dtype=np.float32)
+ expected = np.array([0, 10, 20, 30, 40, 50, 60, 70], dtype=np.float32)
+ self._AssertDynamicStitchResultIs(
+ [val1, val2], [val3, val4], expected=expected)
+
+ def testSimple2D(self):
+ val1 = np.array([0, 4, 7], dtype=np.int32)
+ val2 = np.array([1, 6], dtype=np.int32)
+ val3 = np.array([2, 3, 5], dtype=np.int32)
+ val4 = np.array([[0, 1], [40, 41], [70, 71]], dtype=np.float32)
+ val5 = np.array([[10, 11], [60, 61]], dtype=np.float32)
+ val6 = np.array([[20, 21], [30, 31], [50, 51]], dtype=np.float32)
+ expected = np.array(
+ [[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], [50, 51], [60, 61],
+ [70, 71]],
+ dtype=np.float32)
+ self._AssertDynamicStitchResultIs(
+ [val1, val2, val3], [val4, val5, val6], expected=expected)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/function_test.py b/tensorflow/compiler/tests/function_test.py
new file mode 100644
index 0000000000..40cc7a5d60
--- /dev/null
+++ b/tensorflow/compiler/tests/function_test.py
@@ -0,0 +1,130 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for Tensorflow functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class FunctionTest(XLATestCase):
+
+ def testFunction(self):
+ """Executes a simple TensorFlow function."""
+
+ def APlus2B(a, b):
+ return a + b * 2
+
+ aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
+ expected = APlus2B(aval, bval)
+
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def Foo(a, b):
+ return APlus2B(a, b)
+
+ a = constant_op.constant(aval, name="a")
+ b = constant_op.constant(bval, name="b")
+ with self.test_scope():
+ call_f = Foo(a, b)
+ result = sess.run(call_f)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testNestedFunctions(self):
+ """Executes two nested TensorFlow functions."""
+
+ def TimesTwo(x):
+ return x * 2
+
+ def APlus2B(a, b):
+ return a + TimesTwo(b)
+
+ aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ expected = APlus2B(aval, bval)
+
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def Foo(a, b):
+ return APlus2B(a, b)
+
+ a = constant_op.constant(aval, name="a")
+ b = constant_op.constant(bval, name="b")
+ with self.test_scope():
+ call_g = Foo(a, b)
+ result = sess.run(call_g)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testFunctionMultipleRetvals(self):
+ """Executes a function with multiple return values."""
+
+ # This function will run on the XLA device
+ def Func(a, b):
+ return a + b, a - b
+
+ aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ bval = np.array([5, 6, 7, 8]).reshape([2, 2]).astype(np.float32)
+ expected = Func(aval, bval)
+
+ with self.test_session() as sess:
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def Foo(a, b):
+ return Func(a, b)
+
+ a = constant_op.constant(aval, name="a")
+ b = constant_op.constant(bval, name="b")
+ with self.test_scope():
+ call_f = Foo(a, b)
+ result = sess.run(call_f)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testFunctionsNoInline(self):
+
+ @function.Defun(dtypes.float32, noinline=True)
+ def TimesTwo(x):
+ return x * 2
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def APlus2B(a, b):
+ return a + TimesTwo(b)
+
+ aval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ bval = np.array([4, 3, 2, 1]).reshape([2, 2]).astype(np.float32)
+ expected = aval + bval * 2
+
+ with self.test_session() as sess:
+ with self.test_scope():
+ a = array_ops.placeholder(dtypes.float32, name="a")
+ b = array_ops.placeholder(dtypes.float32, name="b")
+ call = APlus2B(a, b)
+ result = sess.run(call, {a: aval, b: bval})
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py
new file mode 100644
index 0000000000..8a568d6d58
--- /dev/null
+++ b/tensorflow/compiler/tests/jit_test.py
@@ -0,0 +1,459 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for JIT compilation on the CPU and GPU devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.compiler import jit
+from tensorflow.core.framework import function_pb2
+from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+jit_scope = jit.experimental_jit_scope
+
+
+def CompiledKernel(fn, *inputs, **kwargs):
+ """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
+ name = kwargs.pop("name", None)
+ noinline = kwargs.pop("noinline", None)
+
+ @function.Defun(func_name=name, noinline=noinline, compiled=True)
+ def Compiled(*args):
+ return fn(*args)
+
+ return Compiled(*inputs)
+
+
+def RunMetadataLabels(run_metadata):
+ """Returns all labels in run_metadata."""
+ labels = []
+ for dev_stats in run_metadata.step_stats.dev_stats:
+ for node_stats in dev_stats.node_stats:
+ labels.append(node_stats.timeline_label)
+ return labels
+
+
+def InLabels(labels, substr):
+ """Returns true iff one of the labels contains substr."""
+ return any([substr in x for x in labels])
+
+
+def MetadataHasXlaLaunch(run_metadata):
+ """Returns true if there is a _XlaLaunch kernel in run_metadata's timeline."""
+
+ # TODO(phawkins): find a less hacky way to test whether a kernel ran.
+ return InLabels(RunMetadataLabels(run_metadata), "_XlaLaunch")
+
+
+class JitLaunchTest(test.TestCase):
+
+ # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
+ # Verifies that the outputs match and that XLA was invoked. 'fn' must take
+  # the same number of tensor arguments as there are entries in 'args', and
+  # must return a tuple of output tensors.
+ # If 'require_kernel_launch' is True, then we verify that a _XlaLaunch node
+ # actually ran. However, it is sometimes possible for _XlaLaunch ops to be
+ # constant-folded away, so the check is optional.
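+  # A minimal hypothetical invocation (illustration only; not an actual test
+  # in this file) might look like:
+  #   self._compare(lambda x: math_ops.add(x, x),
+  #                 [np.array([1.0, 2.0], dtype=np.float32)])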
+ def _compare(self, fn, args, require_kernel_launch=True, noinline=None):
+ with session_lib.Session() as sess:
+ placeholders = []
+ feeds = {}
+ for arg in args:
+ placeholder = array_ops.placeholder(
+ dtypes.as_dtype(arg.dtype), list(arg.shape))
+ placeholders.append(placeholder)
+ feeds[placeholder] = arg
+
+ compiled_op = CompiledKernel(fn, *placeholders, noinline=noinline)
+ direct_op = fn(*placeholders)
+
+ run_metadata = config_pb2.RunMetadata()
+ compiled = sess.run(compiled_op,
+ feeds,
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ print("Compiled Result {}".format(compiled))
+
+ if require_kernel_launch:
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+
+ direct = sess.run(direct_op, feeds)
+ print("Direct Result {}".format(direct))
+
+ if (isinstance(compiled, (tuple, list)) and
+ (isinstance(direct, (tuple, list)))):
+ for (x, y) in zip(compiled, direct):
+ self.assertAllClose(x, y, rtol=1e-1)
+ else:
+ self.assertAllClose(compiled, direct)
+
+ def testNoOutputs(self):
+ with session_lib.Session() as sess:
+ # Build a function with a single Const node, whose output is ignored.
+ fdef = function_pb2.FunctionDef()
+ fdef.signature.name = "KernelWithNoOutputs"
+ node = node_def_pb2.NodeDef()
+ node.op = "Const"
+ node.name = "ignored"
+ node.attr["dtype"].type = dtypes.int32.as_datatype_enum
+ tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[])
+ node.attr["value"].tensor.CopyFrom(tensor)
+ fdef.node_def.extend([node])
+
+ # Check that calling the result as a compiled kernel doesn't crash.
+ @function.Defun(compiled=True)
+ def KernelWithNoOutputs():
+ return constant_op.constant(100)
+
+      # Hack to override the definition. By accessing .definition, we
+      # force the _DefinedFunction to be initialized internally. Then, we
+      # replace its internal FunctionDef proto. We do this hack here
+      # because one typically can't construct the KernelWithNoOutputs
+      # function via the Defun decorator directly.
+ _ = KernelWithNoOutputs.definition
+ foo = KernelWithNoOutputs
+ foo._definition = fdef
+ call = KernelWithNoOutputs()
+ sess.run(call, {})
+
+ def testAliasing(self):
+ """Regression test for compiled functions that return an aliased buffer.
+
+ XLA returns aliased buffers if outputs are identical. Tests that
+ we handle that case.
+ """
+
+ def AddOnceReturnTwice(x):
+ y = math_ops.add(x, x)
+ return y, y
+
+    # Exercises compiling a function (say, Foo) which calls another
+    # function (say, Bar) that is not inlined. When the compiler compiles
+    # Foo, it needs to symbolically execute Bar correctly regardless of
+    # whether Bar is inlined or not.
+ #
+ # Tests compiled=True and noinline=True.
+ self._compare(
+ AddOnceReturnTwice, [np.array(
+ [[[0.5, -1.0]]], dtype=np.float32)],
+ noinline=True)
+ # Tests compiled=True and noinline=False.
+ self._compare(
+ AddOnceReturnTwice, [np.array(
+ [[[0.5, -1.0]]], dtype=np.float32)],
+ noinline=False)
+
+ def testOneConstOutput(self):
+ """Test consisting of a single constant return value."""
+
+ def OneConstOutput():
+ return constant_op.constant([-3, 44, 99])
+
+ self._compare(OneConstOutput, [], require_kernel_launch=False)
+
+ def testConstZeroElementOutput(self):
+ """Test consisting of a constant zero element return value."""
+
+ def ConstZeroElementOutput():
+ return array_ops.fill([7, 0], 3.0)
+
+ self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)
+
+ def testSomeConstOutputs(self):
+ """Test kernels that return a mixture of const and non-const outputs."""
+
+ def SomeConstOutputs(x):
+ return constant_op.constant(
+ [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)
+
+ self._compare(
+ SomeConstOutputs, [np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])
+
+ def testInt32Input(self):
+ """Test an int32-typed input.
+
+ On a GPU, int32 tensors will be placed in host memory.
+ """
+
+ def AddToSelf(x):
+ return math_ops.add(x, x)
+
+ self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])
+
+ def testMandatoryConstantInput(self):
+ """Tests an operator that has a mandatory-constant shape input."""
+
+ def FillWithFloat(x):
+ return array_ops.fill(x, 9.5)
+
+ self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])
+
+ def testMnistForwardFunc(self):
+ """Compute inference function from MNIST beginners tutorial."""
+ batch_size = 16
+ image_size = 28 * 28
+ num_classes = 10
+
+ # Define a TensorFlow function to compute the forward pass.
+ def MnistForward(w, b, x):
+ return nn_ops.softmax(math_ops.matmul(x, w) + b)
+
+ w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
+ b = np.random.random_sample((num_classes)).astype(np.float32)
+ x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
+ self._compare(MnistForward, [w, b, x])
+
+ def testExplicitMarking(self):
+ """Test explicit marking of operators to compile."""
+ batch_size = 16
+ image_size = 28 * 28
+ num_classes = 10
+
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ w = array_ops.placeholder(dtypes.float32)
+ b = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ y1 = math_ops.matmul(x, w)
+ y2 = math_ops.add(y1, b)
+ with jit_scope():
+ y = math_ops.square(y2)
+
+ dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
+ db = np.random.random_sample((num_classes)).astype(np.float32)
+ dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
+ with session_lib.Session() as sess:
+ run_metadata = config_pb2.RunMetadata()
+ output = sess.run(y, {x: dx,
+ w: dw,
+ b: db},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+
+ # TODO(phawkins): really we would like to test that there were exactly
+ # two kernel launches. However, we have no reliable way to determine
+ # that.
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+
+ expected = np.square(np.dot(dx, dw) + db)
+ self.assertAllClose(expected, output, rtol=1e-1)
+
+
+class XlaCompilationTest(test.TestCase):
+ """Tests for auto-compilation on CPU/GPU devices."""
+
+ def testReshape(self):
+ """Tests an operator with compile-time constant and non-constant inputs."""
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.placeholder(dtypes.int32)
+ with jit_scope():
+ # Reshape's first argument is non-constant in the JIT, but its second
+ # (shape) argument will be treated as a compile-time constant for
+ # each JIT compilation.
+        # We do not use a tf.constant() argument since we want to ensure the
+ # shape is still a run-time argument to the JIT, and not
+ # statically known as part of the JIT compilation's input graph.
+ z = array_ops.reshape(x, y)
+ run_metadata = config_pb2.RunMetadata()
+ out = sess.run(z,
+ {x: np.array([1, 2, 3, 4, 5, 6], np.float32),
+ y: [-1, 3]},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)
+
+ def testIgnoredArguments(self):
+ """Tests that JIT computations can ignore formal parameters."""
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.int32)
+ y = array_ops.placeholder(dtypes.int32)
+ with jit_scope():
+ z = math_ops.add(x, x)
+ w = math_ops.add(y, y)
+ # Pulls 'w' into the same compilation via control dependencies.
+ with ops.control_dependencies([w]):
+ n = control_flow_ops.no_op()
+ with ops.control_dependencies([n]):
+ t = math_ops.add(z, z)
+
+ run_metadata = config_pb2.RunMetadata()
+ out = sess.run(t, {x: np.int32(7),
+ y: np.int32(404)},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(28, out)
+
+ def testLoops(self):
+ """Tests that compilation accepts computations containing loops."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ c = lambda i, _: math_ops.less(i, 5)
+ b = lambda i, x: (i + 1, x * 2.0 + 1.0)
+ _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
+
+ run_metadata = config_pb2.RunMetadata()
+ result = session.run(y, {x: np.float32(2)},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(result, np.float32(95), rtol=1e-1)
+
+ def testCond(self):
+ """Tests that compilation handles switch operators."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.placeholder(dtypes.float32)
+ c = array_ops.placeholder(dtypes.bool)
+ with jit_scope():
+ z = x + 1.0
+ w = control_flow_ops.cond(c, lambda: z, lambda: y)
+ t = math_ops.add(z, w)
+
+ # If JIT compilation chooses to cluster z and t, then execution will
+ # deadlock.
+
+ run_metadata = config_pb2.RunMetadata()
+ result = session.run(t, {x: np.float32(2),
+ y: np.float32(4),
+ c: True},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assert_(MetadataHasXlaLaunch(run_metadata))
+ self.assertAllClose(result, np.float32(6), rtol=1e-1)
+
+ def testNestedFunction(self):
+ g = ops.Graph()
+ with g.as_default():
+
+ @function.Defun(compiled=True)
+ def Bar(x, y):
+ return x + 2 * y
+
+ @function.Defun(compiled=True)
+ def Foo(x):
+ return Bar(x * x, x * x * x)
+
+ @function.Defun()
+ def Entry(x):
+ return Foo(x)
+
+ inp = array_ops.placeholder(dtypes.float32)
+ out = Entry(inp)
+
+ with self.test_session(graph=g, use_gpu=True) as sess:
+ run_metadata = config_pb2.RunMetadata()
+ val = sess.run(out,
+ feed_dict={inp: [2., 10.]},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertAllClose(val, [20., 2100.])
+
+ def testLoopDeadlock(self):
+ """Regression test for bug that caused deadlocks in graphs with loops."""
+
+ with self.test_session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ with jit_scope():
+ y = x + 1.0
+ c = lambda i, _x, _y: math_ops.less(i, 5)
+ b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
+ _, _, w = control_flow_ops.while_loop(c, b,
+ (constant_op.constant(0), y, x))
+ u = w + y
+ result = session.run(u, {x: np.float32(2)})
+ self.assertAllClose(result, np.float32(63), rtol=1e-1)
+
+ def testGradient(self):
+ """Tests that the backprop function is properly compiled."""
+
+ def _Run(compiled):
+
+ @function.Defun(compiled=compiled)
+ def Forward(x):
+ return math_ops.log(x)
+
+ g = ops.Graph()
+ with g.as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ y = Forward(x)
+ dx, = gradients_impl.gradients(y, [x], 1.0)
+
+ cfg = config_pb2.ConfigProto(graph_options=config_pb2.GraphOptions(
+ optimizer_options=config_pb2.OptimizerOptions(
+ opt_level=config_pb2.OptimizerOptions.L1,
+ do_function_inlining=True)))
+ with session_lib.Session(graph=g, config=cfg) as sess:
+ run_metadata = config_pb2.RunMetadata()
+ dx_val = sess.run(dx,
+ feed_dict={x: 100.},
+ run_metadata=run_metadata,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertAllClose(dx_val, 0.01)
+ return RunMetadataLabels(run_metadata)
+
+ # SymGrad[f=log(x)](x, dy) = 1/x * dy
+ #
+ # Note: we don't need to compute log(x) for dx due to graph pruning.
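+    #
+    # Worked example: with dy = 1.0 and x = 100, dx = dy / x = 1 / 100 = 0.01,
+    # which is the value _Run checks with assertAllClose(dx_val, 0.01).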
+
+ # Do not compile the backprop. We should see one Reciprocal and one Mul.
+ labels = _Run(compiled=False)
+ self.assertFalse(InLabels(labels, "Log"))
+ self.assertTrue(InLabels(labels, "Reciprocal"))
+ self.assertTrue(InLabels(labels, "Mul"))
+ self.assertFalse(InLabels(labels, "_XlaLaunch"))
+
+ # Compile the backprop. One _XlaLaunch.
+ labels = _Run(compiled=True)
+ self.assertFalse(InLabels(labels, "Log"))
+ self.assertFalse(InLabels(labels, "Reciprocal"))
+ self.assertFalse(InLabels(labels, "Mul"))
+ self.assertTrue(InLabels(labels, "_XlaLaunch"))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/lrn_ops_test.py b/tensorflow/compiler/tests/lrn_ops_test.py
new file mode 100644
index 0000000000..5d8d89224d
--- /dev/null
+++ b/tensorflow/compiler/tests/lrn_ops_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Local Response Normalization ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.platform import googletest
+
+CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
+
+
+# Local response normalization tests. The forward tests are copied from
+# tensorflow/python/kernel_tests/lrn_op_test.py
+class LRNTest(XLATestCase):
+
+ def _LRN(self, input_image, lrn_depth_radius=5, bias=1.0, alpha=1.0,
+ beta=0.5):
+ """Compute expected result."""
+ output = copy.deepcopy(input_image)
+ batch_size = input_image.shape[0]
+ rows = input_image.shape[1]
+ cols = input_image.shape[2]
+ depth = input_image.shape[3]
+ for b in range(batch_size):
+ for r in range(rows):
+ for c in range(cols):
+ for d in range(depth):
+ begin = max(0, d - lrn_depth_radius)
+ end = min(depth, d + lrn_depth_radius + 1)
+ patch = input_image[b, r, c, begin:end]
+ output[b, r, c, d] /= (
+ np.power(bias + alpha * np.sum(patch * patch), beta))
+ return output
+
+ def _RunAndVerify(self, dtype):
+ with self.test_session():
+ # random shape
+ shape = np.random.randint(1, 16, size=4)
+ # Make depth at least 2 to make it meaningful
+ shape[3] += 1
+ p = array_ops.placeholder(dtype, shape=shape)
+ # random depth_radius, bias, alpha, beta
+ lrn_depth_radius = np.random.randint(1, shape[3])
+ bias = 1.0 + np.random.rand()
+ alpha = 2.0 * np.random.rand()
+ beta = 2.0 * np.random.rand()
+ with self.test_scope():
+ lrn_t = nn.local_response_normalization(
+ p,
+ name="lrn",
+ depth_radius=lrn_depth_radius,
+ bias=bias,
+ alpha=alpha,
+ beta=beta)
+ params = {p: np.random.rand(*shape).astype("f")}
+ result = lrn_t.eval(feed_dict=params)
+ expected = self._LRN(
+ params[p],
+ lrn_depth_radius=lrn_depth_radius,
+ bias=bias,
+ alpha=alpha,
+ beta=beta)
+ err = np.amax(np.abs(result - expected))
+ print("LRN error for bias ", bias, "alpha ", alpha, " beta ", beta, " is ",
+ err)
+ if dtype == dtypes.float32:
+ self.assertTrue(err < 1e-4)
+ else:
+ self.assertTrue(err < 1e-2)
+ self.assertShapeEqual(expected, lrn_t)
+
+ def testCompute(self):
+ for _ in range(2):
+ self._RunAndVerify(dtypes.float32)
+
+ def testLrnGrad(self):
+ # Test for LRNGrad that compares against the CPU implementation.
+ shape = [1, 2, 3, 4]
+ total_size = np.prod(shape)
+ in_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
+ out_image_vals = np.arange(1, total_size + 1, dtype=np.float32)
+ out_grads_vals = np.arange(1, total_size + 1, dtype=np.float32)
+ depth_radius = np.random.randint(1, shape[3])
+ bias = 1.0 + np.random.rand()
+ alpha = 1.0 * np.random.rand()
+ beta = 1.0 * np.random.rand()
+
+ with self.test_session():
+ in_image = constant_op.constant(in_image_vals, shape=shape)
+ out_image = constant_op.constant(out_image_vals, shape=shape)
+ out_grads = constant_op.constant(out_grads_vals, shape=shape)
+ with ops.device(CPU_DEVICE):
+ expected = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
+ depth_radius, bias, alpha, beta)
+ with self.test_scope():
+ actual = gen_nn_ops._lrn_grad(out_grads, in_image, out_image,
+ depth_radius, bias, alpha, beta)
+ expected_val = expected.eval()
+ actual_val = actual.eval()
+ self.assertAllClose(actual_val, expected_val, rtol=1e-3)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/lstm.py b/tensorflow/compiler/tests/lstm.py
new file mode 100644
index 0000000000..18166f51bf
--- /dev/null
+++ b/tensorflow/compiler/tests/lstm.py
@@ -0,0 +1,158 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A simple LSTM layer with benchmarks.
+
+This sets up a simple LSTM (Long Short Term Memory) layer, unrolled to a fixed
+length sequence. The only deviation from standard LSTM cells is that
+activations are clipped, inspired by the GNMT machine translation model.
+The GNMT paper has more details: https://arxiv.org/abs/1609.08144
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+
+
+def Clip(x):
+ """Clips x to the range [-1., 1.]."""
+ return math_ops.maximum(math_ops.minimum(x, 1.), -1.)
+
+
+def LSTMCellWeightsShape(num_inputs, num_nodes):
+ """Returns the shape of the weights for a single LSTM cell."""
+ # Dimension 0 accounts for combining x with the previous m state.
+ # Dimension 1 accounts for the in value and the (in, forget, out) gates.
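+  # For example, with num_inputs=1024 and num_nodes=1024 (the sizes used by
+  # lstm_layer_inference.config.pbtxt in this directory), this is [2048, 4096].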
+ return [num_inputs + num_nodes, 4 * num_nodes]
+
+
+def LSTMCell(weights, m_prev, c_prev, x, pad):
+ """Unrolls a single LSTM cell with clipped activations forward by one step.
+
+ Args:
+ weights: Weight matrix with shape LSTMCellWeightsShape.
+ m_prev: Previous m states with shape [batch_size, num_nodes].
+ c_prev: Previous c states with shape [batch_size, num_nodes].
+ x: Input with shape [batch_size, num_inputs].
+ pad: Padding with shape [batch_size, 1]. Each padding value is either
+ 0 or 1, where 1 indicates padding; i.e. the input is shorter than the
+ sequence length, and the (m, c) states should simply be passed through
+ from the previous states.
+ Returns:
+ The next (m, c) states, each with shape [batch_size, num_nodes].
+ """
+ # Apply weights to the input and previous hidden state.
+ # The matmul here is the "big" operation.
+ xm = array_ops.concat_v2([x, m_prev], 1)
+ xmw = math_ops.matmul(xm, weights)
+
+ # Element-wise ops for the standard LSTM cell, with clipped activations.
+ # XLA can fuse these operations into a single loop.
+ in_value, in_gate, forget_gate, out_gate = array_ops.split(
+ value=xmw, num_or_size_splits=4, axis=1)
+ in_value = math_ops.tanh(in_value)
+ in_gate = math_ops.sigmoid(in_gate)
+ forget_gate = math_ops.sigmoid(forget_gate)
+ out_gate = math_ops.sigmoid(out_gate)
+ c_next = Clip(Clip(forget_gate * c_prev) + Clip(in_gate * in_value))
+ m_next = Clip(out_gate * c_next)
+
+ # Account for padding.
+ c_next = c_prev * pad + c_next * (1.0 - pad)
+ m_next = m_prev * pad + m_next * (1.0 - pad)
+
+ return m_next, c_next
+
+
+def LSTMLayer(cell_name, weights, m, c, x_seq, pad_seq):
+ """Unrolls a layer of LSTM cells forward by the sequence length.
+
+ The sequence length is determined by the length of x_seq and pad_seq, which
+ must be the same.
+
+ Args:
+ cell_name: Base name of each cell.
+ weights: Weight matrix with shape LSTMCellWeightsShape.
+ m: Initial m states with shape [batch_size, num_nodes].
+ c: Initial c states with shape [batch_size, num_nodes].
+ x_seq: List of inputs, each with shape [batch_size, num_inputs].
+ The length of the list is the sequence length.
+ pad_seq: List of paddings, each with shape [batch_size, 1].
+ The length of the list is the sequence length.
+ Each padding value is either 0 or 1, where 1 indicates padding;
+ i.e. the input is shorter than the sequence length.
+ Returns:
+ List of per-sequence-step outputs, each with shape [batch_size, num_nodes].
+ Raises:
+ ValueError: If len(x_seq) != len(pad_seq).
+ """
+ if len(x_seq) != len(pad_seq):
+ raise ValueError('length of x_seq(%d) != pad_seq(%d)' %
+ (len(x_seq), len(pad_seq)))
+ out_seq = []
+ for seq in range(len(x_seq)):
+ with ops.name_scope('%s_%d' % (cell_name, seq)):
+ m, c = LSTMCell(weights, m, c, x_seq[seq], pad_seq[seq])
+ out_seq.append(array_ops.identity(m, name='out'))
+ return out_seq
+
+
+def RandomVar(shape, name=None):
+ """Returns a variable of the given shape initialized to random values."""
+ return variables.Variable(
+ random_ops.random_uniform(shape), dtype=dtypes.float32, name=name)
+
+
+def RandomInputs(batch_size, seq_length, num_inputs):
+ """Returns randomly initialized (x_seq, pad_seq) sequences."""
+ x_seq = []
+ pad_seq = []
+ with ops.name_scope('inputs'):
+ for seq in range(seq_length):
+ x_seq.append(RandomVar([batch_size, num_inputs], name='x_seq_%d' % seq))
+      # Real padding values are always a sequence of 0s followed by a
+      # sequence of 1s, but random values are fine for benchmarking.
+ pad_seq.append(RandomVar([batch_size, 1], name='pad_seq_%d' % seq))
+ return x_seq, pad_seq
+
+
+def BuildLSTMLayer(batch_size, seq_length, num_inputs, num_nodes):
+ """Builds a single LSTM layer with random weights and inputs.
+
+ Args:
+ batch_size: Inputs are fed in batches of this size.
+ seq_length: The sequence length to unroll the LSTM layer.
+ num_inputs: Dimension of inputs that are fed into each LSTM cell.
+ num_nodes: The number of nodes in each LSTM cell.
+
+ Returns:
+ (out_seq, weights) pair. The out_seq is a list of per-sequence-step
+ outputs, each with shape [batch_size, num_nodes]. The weights are a list of
+ weight variables that may be trained.
+ """
+ weights = RandomVar(
+ LSTMCellWeightsShape(num_inputs, num_nodes), name='weights')
+ m = array_ops.zeros([batch_size, num_nodes], name='init_m')
+ c = array_ops.zeros([batch_size, num_nodes], name='init_c')
+ x_seq, pad_seq = RandomInputs(batch_size, seq_length, num_inputs)
+
+ out_seq = LSTMLayer('lstm', weights, m, c, x_seq, pad_seq)
+ return out_seq, [weights]
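+
+
+# A minimal usage sketch (illustration only): the sizes below are arbitrary
+# assumptions, not taken from the benchmark configs in this directory.
+def _ExampleBuildAndRun():
+  """Builds a tiny LSTM layer and evaluates the first step's output."""
+  # Imported here so the sketch stays local to this function.
+  # pylint: disable=g-import-not-at-top
+  from tensorflow.python.client import session as session_lib
+  out_seq, _ = BuildLSTMLayer(batch_size=4, seq_length=3, num_inputs=8,
+                              num_nodes=8)
+  with session_lib.Session() as sess:
+    sess.run(variables.global_variables_initializer())
+    return sess.run(out_seq[0])  # A [4, 8] numpy array.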
diff --git a/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
new file mode 100644
index 0000000000..c46e65f71a
--- /dev/null
+++ b/tensorflow/compiler/tests/lstm_layer_inference.config.pbtxt
@@ -0,0 +1,20 @@
+# Text form of tensorflow.tfcompile.Config proto.
+feed{ id{node_name:"inputs/x_seq_0/read"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"inputs/x_seq_1/read"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"inputs/x_seq_2/read"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"inputs/x_seq_3/read"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"inputs/x_seq_4/read"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"inputs/pad_seq_0/read"} shape{dim{size:128}dim{size:1}} }
+feed{ id{node_name:"inputs/pad_seq_1/read"} shape{dim{size:128}dim{size:1}} }
+feed{ id{node_name:"inputs/pad_seq_2/read"} shape{dim{size:128}dim{size:1}} }
+feed{ id{node_name:"inputs/pad_seq_3/read"} shape{dim{size:128}dim{size:1}} }
+feed{ id{node_name:"inputs/pad_seq_4/read"} shape{dim{size:128}dim{size:1}} }
+feed{ id{node_name:"weights/read"} shape{dim{size:2048}dim{size:4096}} }
+feed{ id{node_name:"init_c"} shape{dim{size:128}dim{size:1024}} }
+feed{ id{node_name:"init_m"} shape{dim{size:128}dim{size:1024}} }
+
+fetch{ id{node_name:"lstm_0/out"} }
+fetch{ id{node_name:"lstm_1/out"} }
+fetch{ id{node_name:"lstm_2/out"} }
+fetch{ id{node_name:"lstm_3/out"} }
+fetch{ id{node_name:"lstm_4/out"} }
diff --git a/tensorflow/compiler/tests/lstm_layer_inference.pbtxt b/tensorflow/compiler/tests/lstm_layer_inference.pbtxt
new file mode 100644
index 0000000000..649f4a8ff1
--- /dev/null
+++ b/tensorflow/compiler/tests/lstm_layer_inference.pbtxt
@@ -0,0 +1,5828 @@
+# Generated by running lstm_test, setting --dump_graph_dir.
+
+node {
+ name: "random_uniform/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\000\010\000\000\000\020\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "random_uniform/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "random_uniform/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "random_uniform/RandomUniform"
+ op: "RandomUniform"
+ input: "random_uniform/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "random_uniform/sub"
+ op: "Sub"
+ input: "random_uniform/max"
+ input: "random_uniform/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "random_uniform/mul"
+ op: "Mul"
+ input: "random_uniform/RandomUniform"
+ input: "random_uniform/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "random_uniform"
+ op: "Add"
+ input: "random_uniform/mul"
+ input: "random_uniform/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "weights"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 2048
+ }
+ dim {
+ size: 4096
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "weights/Assign"
+ op: "Assign"
+ input: "weights"
+ input: "random_uniform"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@weights"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "weights/read"
+ op: "Identity"
+ input: "weights"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@weights"
+ }
+ }
+ }
+}
+node {
+ name: "init_m"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "init_c"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\000\004\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/sub"
+ op: "Sub"
+ input: "inputs/random_uniform/max"
+ input: "inputs/random_uniform/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform/mul"
+ op: "Mul"
+ input: "inputs/random_uniform/RandomUniform"
+ input: "inputs/random_uniform/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform"
+ op: "Add"
+ input: "inputs/random_uniform/mul"
+ input: "inputs/random_uniform/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_0"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_0/Assign"
+ op: "Assign"
+ input: "inputs/x_seq_0"
+ input: "inputs/random_uniform"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_0"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_0/read"
+ op: "Identity"
+ input: "inputs/x_seq_0"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_0"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_1/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_1/max"
+ input: "inputs/random_uniform_1/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_1/RandomUniform"
+ input: "inputs/random_uniform_1/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_1"
+ op: "Add"
+ input: "inputs/random_uniform_1/mul"
+ input: "inputs/random_uniform_1/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_0"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_0/Assign"
+ op: "Assign"
+ input: "inputs/pad_seq_0"
+ input: "inputs/random_uniform_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_0"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_0/read"
+ op: "Identity"
+ input: "inputs/pad_seq_0"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_0"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\000\004\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_2/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_2/max"
+ input: "inputs/random_uniform_2/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_2/RandomUniform"
+ input: "inputs/random_uniform_2/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_2"
+ op: "Add"
+ input: "inputs/random_uniform_2/mul"
+ input: "inputs/random_uniform_2/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_1"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_1/Assign"
+ op: "Assign"
+ input: "inputs/x_seq_1"
+ input: "inputs/random_uniform_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_1"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_1/read"
+ op: "Identity"
+ input: "inputs/x_seq_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_1"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_3/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_3/max"
+ input: "inputs/random_uniform_3/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_3/RandomUniform"
+ input: "inputs/random_uniform_3/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_3"
+ op: "Add"
+ input: "inputs/random_uniform_3/mul"
+ input: "inputs/random_uniform_3/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_1"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_1/Assign"
+ op: "Assign"
+ input: "inputs/pad_seq_1"
+ input: "inputs/random_uniform_3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_1"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_1/read"
+ op: "Identity"
+ input: "inputs/pad_seq_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_1"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\000\004\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_4/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_4/max"
+ input: "inputs/random_uniform_4/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_4/RandomUniform"
+ input: "inputs/random_uniform_4/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_4"
+ op: "Add"
+ input: "inputs/random_uniform_4/mul"
+ input: "inputs/random_uniform_4/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_2"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_2/Assign"
+ op: "Assign"
+ input: "inputs/x_seq_2"
+ input: "inputs/random_uniform_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_2"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_2/read"
+ op: "Identity"
+ input: "inputs/x_seq_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_2"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_5/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_5/max"
+ input: "inputs/random_uniform_5/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_5/RandomUniform"
+ input: "inputs/random_uniform_5/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_5"
+ op: "Add"
+ input: "inputs/random_uniform_5/mul"
+ input: "inputs/random_uniform_5/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_2"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_2/Assign"
+ op: "Assign"
+ input: "inputs/pad_seq_2"
+ input: "inputs/random_uniform_5"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_2"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_2/read"
+ op: "Identity"
+ input: "inputs/pad_seq_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_2"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\000\004\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_6/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_6/max"
+ input: "inputs/random_uniform_6/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_6/RandomUniform"
+ input: "inputs/random_uniform_6/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_6"
+ op: "Add"
+ input: "inputs/random_uniform_6/mul"
+ input: "inputs/random_uniform_6/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_3"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_3/Assign"
+ op: "Assign"
+ input: "inputs/x_seq_3"
+ input: "inputs/random_uniform_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_3"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_3/read"
+ op: "Identity"
+ input: "inputs/x_seq_3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_3"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_7/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_7/max"
+ input: "inputs/random_uniform_7/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_7/RandomUniform"
+ input: "inputs/random_uniform_7/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_7"
+ op: "Add"
+ input: "inputs/random_uniform_7/mul"
+ input: "inputs/random_uniform_7/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_3"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_3/Assign"
+ op: "Assign"
+ input: "inputs/pad_seq_3"
+ input: "inputs/random_uniform_7"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_3"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_3/read"
+ op: "Identity"
+ input: "inputs/pad_seq_3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_3"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\000\004\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_8/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_8/max"
+ input: "inputs/random_uniform_8/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_8/RandomUniform"
+ input: "inputs/random_uniform_8/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_8"
+ op: "Add"
+ input: "inputs/random_uniform_8/mul"
+ input: "inputs/random_uniform_8/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_4"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1024
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_4/Assign"
+ op: "Assign"
+ input: "inputs/x_seq_4"
+ input: "inputs/random_uniform_8"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_4"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/x_seq_4/read"
+ op: "Identity"
+ input: "inputs/x_seq_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/x_seq_4"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/shape"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: "\200\000\000\000\001\000\000\000"
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/min"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/max"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/RandomUniform"
+ op: "RandomUniform"
+ input: "inputs/random_uniform_9/shape"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "seed"
+ value {
+ i: 0
+ }
+ }
+ attr {
+ key: "seed2"
+ value {
+ i: 0
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/sub"
+ op: "Sub"
+ input: "inputs/random_uniform_9/max"
+ input: "inputs/random_uniform_9/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9/mul"
+ op: "Mul"
+ input: "inputs/random_uniform_9/RandomUniform"
+ input: "inputs/random_uniform_9/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/random_uniform_9"
+ op: "Add"
+ input: "inputs/random_uniform_9/mul"
+ input: "inputs/random_uniform_9/min"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_4"
+ op: "Variable"
+ device: "/device:GPU:*"
+ attr {
+ key: "container"
+ value {
+ s: ""
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 128
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ attr {
+ key: "shared_name"
+ value {
+ s: ""
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_4/Assign"
+ op: "Assign"
+ input: "inputs/pad_seq_4"
+ input: "inputs/random_uniform_9"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_4"
+ }
+ }
+ }
+ attr {
+ key: "use_locking"
+ value {
+ b: true
+ }
+ }
+ attr {
+ key: "validate_shape"
+ value {
+ b: true
+ }
+ }
+}
+node {
+ name: "inputs/pad_seq_4/read"
+ op: "Identity"
+ input: "inputs/pad_seq_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@inputs/pad_seq_4"
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/concat/axis"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/concat"
+ op: "ConcatV2"
+ input: "inputs/x_seq_0/read"
+ input: "init_m"
+ input: "lstm_0/concat/axis"
+ device: "/device:GPU:*"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "lstm_0/MatMul"
+ op: "MatMul"
+ input: "lstm_0/concat"
+ input: "weights/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "lstm_0/split/split_dim"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/split"
+ op: "Split"
+ input: "lstm_0/split/split_dim"
+ input: "lstm_0/MatMul"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "num_split"
+ value {
+ i: 4
+ }
+ }
+}
+node {
+ name: "lstm_0/Tanh"
+ op: "Tanh"
+ input: "lstm_0/split"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Sigmoid"
+ op: "Sigmoid"
+ input: "lstm_0/split:1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Sigmoid_1"
+ op: "Sigmoid"
+ input: "lstm_0/split:2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Sigmoid_2"
+ op: "Sigmoid"
+ input: "lstm_0/split:3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul"
+ op: "Mul"
+ input: "lstm_0/Sigmoid_1"
+ input: "init_c"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum"
+ op: "Minimum"
+ input: "lstm_0/mul"
+ input: "lstm_0/Minimum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum"
+ op: "Maximum"
+ input: "lstm_0/Minimum"
+ input: "lstm_0/Maximum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_1"
+ op: "Mul"
+ input: "lstm_0/Sigmoid"
+ input: "lstm_0/Tanh"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_1"
+ op: "Minimum"
+ input: "lstm_0/mul_1"
+ input: "lstm_0/Minimum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_1"
+ op: "Maximum"
+ input: "lstm_0/Minimum_1"
+ input: "lstm_0/Maximum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/add"
+ op: "Add"
+ input: "lstm_0/Maximum"
+ input: "lstm_0/Maximum_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_2"
+ op: "Minimum"
+ input: "lstm_0/add"
+ input: "lstm_0/Minimum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_2"
+ op: "Maximum"
+ input: "lstm_0/Minimum_2"
+ input: "lstm_0/Maximum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_2"
+ op: "Mul"
+ input: "lstm_0/Sigmoid_2"
+ input: "lstm_0/Maximum_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Minimum_3"
+ op: "Minimum"
+ input: "lstm_0/mul_2"
+ input: "lstm_0/Minimum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/Maximum_3"
+ op: "Maximum"
+ input: "lstm_0/Minimum_3"
+ input: "lstm_0/Maximum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_3"
+ op: "Mul"
+ input: "init_c"
+ input: "inputs/pad_seq_0/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/sub/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/sub"
+ op: "Sub"
+ input: "lstm_0/sub/x"
+ input: "inputs/pad_seq_0/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_4"
+ op: "Mul"
+ input: "lstm_0/Maximum_2"
+ input: "lstm_0/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/add_1"
+ op: "Add"
+ input: "lstm_0/mul_3"
+ input: "lstm_0/mul_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_5"
+ op: "Mul"
+ input: "init_m"
+ input: "inputs/pad_seq_0/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/sub_1/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_0/sub_1"
+ op: "Sub"
+ input: "lstm_0/sub_1/x"
+ input: "inputs/pad_seq_0/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/mul_6"
+ op: "Mul"
+ input: "lstm_0/Maximum_3"
+ input: "lstm_0/sub_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/add_2"
+ op: "Add"
+ input: "lstm_0/mul_5"
+ input: "lstm_0/mul_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_0/out"
+ op: "Identity"
+ input: "lstm_0/add_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/concat/axis"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/concat"
+ op: "ConcatV2"
+ input: "inputs/x_seq_1/read"
+ input: "lstm_0/add_2"
+ input: "lstm_1/concat/axis"
+ device: "/device:GPU:*"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "lstm_1/MatMul"
+ op: "MatMul"
+ input: "lstm_1/concat"
+ input: "weights/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "lstm_1/split/split_dim"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/split"
+ op: "Split"
+ input: "lstm_1/split/split_dim"
+ input: "lstm_1/MatMul"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "num_split"
+ value {
+ i: 4
+ }
+ }
+}
+node {
+ name: "lstm_1/Tanh"
+ op: "Tanh"
+ input: "lstm_1/split"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Sigmoid"
+ op: "Sigmoid"
+ input: "lstm_1/split:1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Sigmoid_1"
+ op: "Sigmoid"
+ input: "lstm_1/split:2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Sigmoid_2"
+ op: "Sigmoid"
+ input: "lstm_1/split:3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul"
+ op: "Mul"
+ input: "lstm_1/Sigmoid_1"
+ input: "lstm_0/add_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum"
+ op: "Minimum"
+ input: "lstm_1/mul"
+ input: "lstm_1/Minimum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum"
+ op: "Maximum"
+ input: "lstm_1/Minimum"
+ input: "lstm_1/Maximum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_1"
+ op: "Mul"
+ input: "lstm_1/Sigmoid"
+ input: "lstm_1/Tanh"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_1"
+ op: "Minimum"
+ input: "lstm_1/mul_1"
+ input: "lstm_1/Minimum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_1"
+ op: "Maximum"
+ input: "lstm_1/Minimum_1"
+ input: "lstm_1/Maximum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/add"
+ op: "Add"
+ input: "lstm_1/Maximum"
+ input: "lstm_1/Maximum_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_2"
+ op: "Minimum"
+ input: "lstm_1/add"
+ input: "lstm_1/Minimum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_2"
+ op: "Maximum"
+ input: "lstm_1/Minimum_2"
+ input: "lstm_1/Maximum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_2"
+ op: "Mul"
+ input: "lstm_1/Sigmoid_2"
+ input: "lstm_1/Maximum_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Minimum_3"
+ op: "Minimum"
+ input: "lstm_1/mul_2"
+ input: "lstm_1/Minimum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/Maximum_3"
+ op: "Maximum"
+ input: "lstm_1/Minimum_3"
+ input: "lstm_1/Maximum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_3"
+ op: "Mul"
+ input: "lstm_0/add_1"
+ input: "inputs/pad_seq_1/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/sub/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/sub"
+ op: "Sub"
+ input: "lstm_1/sub/x"
+ input: "inputs/pad_seq_1/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_4"
+ op: "Mul"
+ input: "lstm_1/Maximum_2"
+ input: "lstm_1/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/add_1"
+ op: "Add"
+ input: "lstm_1/mul_3"
+ input: "lstm_1/mul_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_5"
+ op: "Mul"
+ input: "lstm_0/add_2"
+ input: "inputs/pad_seq_1/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/sub_1/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_1/sub_1"
+ op: "Sub"
+ input: "lstm_1/sub_1/x"
+ input: "inputs/pad_seq_1/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/mul_6"
+ op: "Mul"
+ input: "lstm_1/Maximum_3"
+ input: "lstm_1/sub_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/add_2"
+ op: "Add"
+ input: "lstm_1/mul_5"
+ input: "lstm_1/mul_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_1/out"
+ op: "Identity"
+ input: "lstm_1/add_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/concat/axis"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/concat"
+ op: "ConcatV2"
+ input: "inputs/x_seq_2/read"
+ input: "lstm_1/add_2"
+ input: "lstm_2/concat/axis"
+ device: "/device:GPU:*"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "lstm_2/MatMul"
+ op: "MatMul"
+ input: "lstm_2/concat"
+ input: "weights/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "lstm_2/split/split_dim"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/split"
+ op: "Split"
+ input: "lstm_2/split/split_dim"
+ input: "lstm_2/MatMul"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "num_split"
+ value {
+ i: 4
+ }
+ }
+}
+node {
+ name: "lstm_2/Tanh"
+ op: "Tanh"
+ input: "lstm_2/split"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Sigmoid"
+ op: "Sigmoid"
+ input: "lstm_2/split:1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Sigmoid_1"
+ op: "Sigmoid"
+ input: "lstm_2/split:2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Sigmoid_2"
+ op: "Sigmoid"
+ input: "lstm_2/split:3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul"
+ op: "Mul"
+ input: "lstm_2/Sigmoid_1"
+ input: "lstm_1/add_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum"
+ op: "Minimum"
+ input: "lstm_2/mul"
+ input: "lstm_2/Minimum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum"
+ op: "Maximum"
+ input: "lstm_2/Minimum"
+ input: "lstm_2/Maximum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_1"
+ op: "Mul"
+ input: "lstm_2/Sigmoid"
+ input: "lstm_2/Tanh"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_1"
+ op: "Minimum"
+ input: "lstm_2/mul_1"
+ input: "lstm_2/Minimum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_1"
+ op: "Maximum"
+ input: "lstm_2/Minimum_1"
+ input: "lstm_2/Maximum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/add"
+ op: "Add"
+ input: "lstm_2/Maximum"
+ input: "lstm_2/Maximum_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_2"
+ op: "Minimum"
+ input: "lstm_2/add"
+ input: "lstm_2/Minimum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_2"
+ op: "Maximum"
+ input: "lstm_2/Minimum_2"
+ input: "lstm_2/Maximum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_2"
+ op: "Mul"
+ input: "lstm_2/Sigmoid_2"
+ input: "lstm_2/Maximum_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Minimum_3"
+ op: "Minimum"
+ input: "lstm_2/mul_2"
+ input: "lstm_2/Minimum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/Maximum_3"
+ op: "Maximum"
+ input: "lstm_2/Minimum_3"
+ input: "lstm_2/Maximum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_3"
+ op: "Mul"
+ input: "lstm_1/add_1"
+ input: "inputs/pad_seq_2/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/sub/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/sub"
+ op: "Sub"
+ input: "lstm_2/sub/x"
+ input: "inputs/pad_seq_2/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_4"
+ op: "Mul"
+ input: "lstm_2/Maximum_2"
+ input: "lstm_2/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/add_1"
+ op: "Add"
+ input: "lstm_2/mul_3"
+ input: "lstm_2/mul_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_5"
+ op: "Mul"
+ input: "lstm_1/add_2"
+ input: "inputs/pad_seq_2/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/sub_1/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_2/sub_1"
+ op: "Sub"
+ input: "lstm_2/sub_1/x"
+ input: "inputs/pad_seq_2/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/mul_6"
+ op: "Mul"
+ input: "lstm_2/Maximum_3"
+ input: "lstm_2/sub_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/add_2"
+ op: "Add"
+ input: "lstm_2/mul_5"
+ input: "lstm_2/mul_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_2/out"
+ op: "Identity"
+ input: "lstm_2/add_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/concat/axis"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/concat"
+ op: "ConcatV2"
+ input: "inputs/x_seq_3/read"
+ input: "lstm_2/add_2"
+ input: "lstm_3/concat/axis"
+ device: "/device:GPU:*"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "lstm_3/MatMul"
+ op: "MatMul"
+ input: "lstm_3/concat"
+ input: "weights/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "lstm_3/split/split_dim"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/split"
+ op: "Split"
+ input: "lstm_3/split/split_dim"
+ input: "lstm_3/MatMul"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "num_split"
+ value {
+ i: 4
+ }
+ }
+}
+node {
+ name: "lstm_3/Tanh"
+ op: "Tanh"
+ input: "lstm_3/split"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Sigmoid"
+ op: "Sigmoid"
+ input: "lstm_3/split:1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Sigmoid_1"
+ op: "Sigmoid"
+ input: "lstm_3/split:2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Sigmoid_2"
+ op: "Sigmoid"
+ input: "lstm_3/split:3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul"
+ op: "Mul"
+ input: "lstm_3/Sigmoid_1"
+ input: "lstm_2/add_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum"
+ op: "Minimum"
+ input: "lstm_3/mul"
+ input: "lstm_3/Minimum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum"
+ op: "Maximum"
+ input: "lstm_3/Minimum"
+ input: "lstm_3/Maximum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_1"
+ op: "Mul"
+ input: "lstm_3/Sigmoid"
+ input: "lstm_3/Tanh"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_1"
+ op: "Minimum"
+ input: "lstm_3/mul_1"
+ input: "lstm_3/Minimum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_1"
+ op: "Maximum"
+ input: "lstm_3/Minimum_1"
+ input: "lstm_3/Maximum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/add"
+ op: "Add"
+ input: "lstm_3/Maximum"
+ input: "lstm_3/Maximum_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_2"
+ op: "Minimum"
+ input: "lstm_3/add"
+ input: "lstm_3/Minimum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_2"
+ op: "Maximum"
+ input: "lstm_3/Minimum_2"
+ input: "lstm_3/Maximum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_2"
+ op: "Mul"
+ input: "lstm_3/Sigmoid_2"
+ input: "lstm_3/Maximum_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Minimum_3"
+ op: "Minimum"
+ input: "lstm_3/mul_2"
+ input: "lstm_3/Minimum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/Maximum_3"
+ op: "Maximum"
+ input: "lstm_3/Minimum_3"
+ input: "lstm_3/Maximum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_3"
+ op: "Mul"
+ input: "lstm_2/add_1"
+ input: "inputs/pad_seq_3/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/sub/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/sub"
+ op: "Sub"
+ input: "lstm_3/sub/x"
+ input: "inputs/pad_seq_3/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_4"
+ op: "Mul"
+ input: "lstm_3/Maximum_2"
+ input: "lstm_3/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/add_1"
+ op: "Add"
+ input: "lstm_3/mul_3"
+ input: "lstm_3/mul_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_5"
+ op: "Mul"
+ input: "lstm_2/add_2"
+ input: "inputs/pad_seq_3/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/sub_1/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_3/sub_1"
+ op: "Sub"
+ input: "lstm_3/sub_1/x"
+ input: "inputs/pad_seq_3/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/mul_6"
+ op: "Mul"
+ input: "lstm_3/Maximum_3"
+ input: "lstm_3/sub_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/add_2"
+ op: "Add"
+ input: "lstm_3/mul_5"
+ input: "lstm_3/mul_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_3/out"
+ op: "Identity"
+ input: "lstm_3/add_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/concat/axis"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/concat"
+ op: "ConcatV2"
+ input: "inputs/x_seq_4/read"
+ input: "lstm_3/add_2"
+ input: "lstm_4/concat/axis"
+ device: "/device:GPU:*"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "lstm_4/MatMul"
+ op: "MatMul"
+ input: "lstm_4/concat"
+ input: "weights/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+}
+node {
+ name: "lstm_4/split/split_dim"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/split"
+ op: "Split"
+ input: "lstm_4/split/split_dim"
+ input: "lstm_4/MatMul"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "num_split"
+ value {
+ i: 4
+ }
+ }
+}
+node {
+ name: "lstm_4/Tanh"
+ op: "Tanh"
+ input: "lstm_4/split"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Sigmoid"
+ op: "Sigmoid"
+ input: "lstm_4/split:1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Sigmoid_1"
+ op: "Sigmoid"
+ input: "lstm_4/split:2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Sigmoid_2"
+ op: "Sigmoid"
+ input: "lstm_4/split:3"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul"
+ op: "Mul"
+ input: "lstm_4/Sigmoid_1"
+ input: "lstm_3/add_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum"
+ op: "Minimum"
+ input: "lstm_4/mul"
+ input: "lstm_4/Minimum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum"
+ op: "Maximum"
+ input: "lstm_4/Minimum"
+ input: "lstm_4/Maximum/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_1"
+ op: "Mul"
+ input: "lstm_4/Sigmoid"
+ input: "lstm_4/Tanh"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_1"
+ op: "Minimum"
+ input: "lstm_4/mul_1"
+ input: "lstm_4/Minimum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_1/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_1"
+ op: "Maximum"
+ input: "lstm_4/Minimum_1"
+ input: "lstm_4/Maximum_1/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/add"
+ op: "Add"
+ input: "lstm_4/Maximum"
+ input: "lstm_4/Maximum_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_2"
+ op: "Minimum"
+ input: "lstm_4/add"
+ input: "lstm_4/Minimum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_2/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_2"
+ op: "Maximum"
+ input: "lstm_4/Minimum_2"
+ input: "lstm_4/Maximum_2/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_2"
+ op: "Mul"
+ input: "lstm_4/Sigmoid_2"
+ input: "lstm_4/Maximum_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Minimum_3"
+ op: "Minimum"
+ input: "lstm_4/mul_2"
+ input: "lstm_4/Minimum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_3/y"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: -1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/Maximum_3"
+ op: "Maximum"
+ input: "lstm_4/Minimum_3"
+ input: "lstm_4/Maximum_3/y"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_3"
+ op: "Mul"
+ input: "lstm_3/add_1"
+ input: "inputs/pad_seq_4/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/sub/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/sub"
+ op: "Sub"
+ input: "lstm_4/sub/x"
+ input: "inputs/pad_seq_4/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_4"
+ op: "Mul"
+ input: "lstm_4/Maximum_2"
+ input: "lstm_4/sub"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/add_1"
+ op: "Add"
+ input: "lstm_4/mul_3"
+ input: "lstm_4/mul_4"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_5"
+ op: "Mul"
+ input: "lstm_3/add_2"
+ input: "inputs/pad_seq_4/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/sub_1/x"
+ op: "Const"
+ device: "/device:GPU:*"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "lstm_4/sub_1"
+ op: "Sub"
+ input: "lstm_4/sub_1/x"
+ input: "inputs/pad_seq_4/read"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/mul_6"
+ op: "Mul"
+ input: "lstm_4/Maximum_3"
+ input: "lstm_4/sub_1"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/add_2"
+ op: "Add"
+ input: "lstm_4/mul_5"
+ input: "lstm_4/mul_6"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "lstm_4/out"
+ op: "Identity"
+ input: "lstm_4/add_2"
+ device: "/device:GPU:*"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+versions {
+ producer: 19
+}
diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py
new file mode 100644
index 0000000000..9ffeb6c2a2
--- /dev/null
+++ b/tensorflow/compiler/tests/lstm_test.py
@@ -0,0 +1,293 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the LSTM cell and layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.compiler.tests import lstm
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import flags as flags_lib
+from tensorflow.python.platform import test
+
+flags = flags_lib
+FLAGS = flags.FLAGS
+
+flags.DEFINE_integer('batch_size', 128,
+ 'Inputs are fed in batches of this size, for both '
+ 'inference and training. Larger values cause the matmul '
+ 'in each LSTM cell to have higher dimensionality.')
+flags.DEFINE_integer('seq_length', 60,
+                     'Length of the unrolled sequence of LSTM cells in a layer. '
+                     'Larger values cause more LSTM matmuls to be run.')
+flags.DEFINE_integer('num_inputs', 1024,
+ 'Dimension of inputs that are fed into each LSTM cell.')
+flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
+flags.DEFINE_string('device', 'gpu',
+ 'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
+ 'For details see documentation for tf.Graph.device.')
+
+flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
+ '*.pbtxt format to this directory.')
+
+
+def _DumpGraph(graph, basename):
+ if FLAGS.dump_graph_dir:
+ name = os.path.join(FLAGS.dump_graph_dir, basename + '.pbtxt')
+ with open(name, 'w') as f:
+ f.write(str(graph.as_graph_def()))
+
+
+def _Sigmoid(x):
+ return 1. / (1. + np.exp(-x))
+
+
+def _Clip(x):
+ return np.maximum(np.minimum(x, 1.), -1.)
+
+
+class LSTMTest(test.TestCase):
+
+ def setUp(self):
+    # The tests for a single LSTM cell and LSTM layer use these values as
+    # inputs. We always set num_inputs=1, so the batch dimension effectively
+    # enumerates the different input cases.
+ self._inputs = np.array([[-1.], [-.5], [0.], [.5], [1.]], np.float32)
+ self._batch_size = len(self._inputs)
+
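+  # The two reference helpers below model the recurrence that lstm.LSTMCell is
+  # expected to compute (see lstm.py for the actual op construction). With
+  # num_inputs = num_nodes = 1, the same pre-activation
+  #   x = (inputs + m_prev) * weight
+  # feeds every gate, so the cell collapses to
+  #   c_next = Clip(Clip(sigmoid(x) * c_prev) + Clip(sigmoid(x) * tanh(x)))
+  #   m_next = Clip(sigmoid(x) * c_next)
+  # where Clip clamps to [-1, 1], mirroring the Minimum/Maximum node pairs in
+  # the dumped graph above.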
+ def _NextC(self, inputs, weight, m_prev, c_prev):
+ """Returns the next c states of an LSTM cell."""
+ x = (inputs + m_prev) * weight
+ return _Clip(_Clip(_Sigmoid(x) * c_prev) + _Clip(_Sigmoid(x) * np.tanh(x)))
+
+ def _NextM(self, inputs, weight, m_prev, c_prev):
+ """Returns the next m states of an LSTM cell."""
+ x = (inputs + m_prev) * weight
+ return _Clip(_Sigmoid(x) * self._NextC(inputs, weight, m_prev, c_prev))
+
+ def _RunLSTMCell(self, basename, init_weights, m_prev_scalar, c_prev_scalar,
+ pad_scalar):
+ with self.test_session() as sess:
+ num_inputs = 1
+ num_nodes = 1
+
+ weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
+ m_prev = constant_op.constant([[m_prev_scalar]] * self._batch_size)
+ c_prev = constant_op.constant([[c_prev_scalar]] * self._batch_size)
+ x = constant_op.constant(self._inputs)
+ pad = constant_op.constant([[pad_scalar]] * self._batch_size)
+
+ m, c = lstm.LSTMCell(weights, m_prev, c_prev, x, pad)
+ _DumpGraph(sess.graph, 'lstm_cell_%s_%d_%d_%d' %
+ (basename, m_prev_scalar, c_prev_scalar, pad_scalar))
+
+ # Initialize variables and run the unrolled LSTM step.
+ sess.run(variables.global_variables_initializer())
+ return sess.run([m, c])
+
+ def testLSTMCell(self):
+ # Run with all-0 weights, no padding.
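+    # With all-zero weights the pre-activation x is 0 for every gate, so
+    # sigmoid(x) = 0.5 and tanh(x) = 0. Starting from c_prev = 1 this gives
+    # c = Clip(0.5 * 1) + Clip(0.5 * 0) = 0.5 and m = Clip(0.5 * 0.5) = 0.25;
+    # a zero c_prev keeps both states at zero.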
+ m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
+ self.assertAllClose(m, [[0.]] * self._batch_size)
+ self.assertAllClose(c, [[0.]] * self._batch_size)
+ m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
+ self.assertAllClose(m, [[.25]] * self._batch_size)
+ self.assertAllClose(c, [[.5]] * self._batch_size)
+ m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
+ self.assertAllClose(m, [[.0]] * self._batch_size)
+ self.assertAllClose(c, [[.0]] * self._batch_size)
+ m, c = self._RunLSTMCell('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
+ self.assertAllClose(m, [[.25]] * self._batch_size)
+ self.assertAllClose(c, [[.5]] * self._batch_size)
+
+ # Run with all-1 weights, no padding.
+ for m_prev in [0., 1.]:
+ for c_prev in [0., 1.]:
+ m, c = self._RunLSTMCell('ones',
+ init_ops.ones_initializer(), m_prev, c_prev,
+ 0.)
+ self.assertAllClose(m, self._NextM(self._inputs, 1., m_prev, c_prev))
+ self.assertAllClose(c, self._NextC(self._inputs, 1., m_prev, c_prev))
+
+ # Run with random weights.
+ for weight in np.random.rand(3):
+ weight_tf = constant_op.constant(weight, dtypes.float32)
+ random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
+
+ # No padding.
+ for m_prev in [0., 1.]:
+ for c_prev in [0., 1.]:
+ m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 0.)
+ self.assertAllClose(m,
+ self._NextM(self._inputs, weight, m_prev, c_prev))
+ self.assertAllClose(c,
+ self._NextC(self._inputs, weight, m_prev, c_prev))
+
+ # Set padding.
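+      # With pad = 1 the cell acts as a pass-through: the states are mixed as
+      # c_prev * pad + c * (1 - pad) (and likewise for m), so the outputs
+      # should equal the initial m_prev and c_prev exactly.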
+ for m_prev in [0., 1.]:
+ for c_prev in [0., 1.]:
+ m, c = self._RunLSTMCell('random', random_weight, m_prev, c_prev, 1.)
+ self.assertAllClose(m, [[m_prev]] * self._batch_size)
+ self.assertAllClose(c, [[c_prev]] * self._batch_size)
+
+ def testLSTMLayerErrors(self):
+ num_inputs = 1
+ num_nodes = 1
+ seq_length = 3
+
+ weights = array_ops.zeros(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
+ m = constant_op.constant([[0.]] * self._batch_size)
+ c = constant_op.constant([[0.]] * self._batch_size)
+ x_seq = [constant_op.constant(self._inputs)] * seq_length
+ pad = constant_op.constant([[0.]] * self._batch_size)
+
+ with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
+ lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad])
+ with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
+ lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 2)
+ with self.assertRaisesWithPredicateMatch(ValueError, 'length of x_seq'):
+ lstm.LSTMLayer('lstm', weights, m, c, x_seq, [pad] * 4)
+
+ def _RunLSTMLayer(self, basename, init_weights, m_init_scalar, c_init_scalar,
+ pad_scalar):
+ with self.test_session() as sess:
+ num_inputs = 1
+ num_nodes = 1
+ seq_length = 3
+
+ weights = init_weights(lstm.LSTMCellWeightsShape(num_inputs, num_nodes))
+ m_init = constant_op.constant([[m_init_scalar]] * self._batch_size)
+ c_init = constant_op.constant([[c_init_scalar]] * self._batch_size)
+ x_seq = [constant_op.constant(self._inputs)] * seq_length
+ pad_seq = [constant_op.constant([[pad_scalar]] * self._batch_size)
+ ] * seq_length
+
+ out_seq = lstm.LSTMLayer('lstm', weights, m_init, c_init, x_seq, pad_seq)
+ _DumpGraph(sess.graph, 'lstm_layer_%s_%d_%d_%d' %
+ (basename, m_init_scalar, c_init_scalar, pad_scalar))
+
+ # Initialize variables and run the unrolled LSTM layer.
+ sess.run(variables.global_variables_initializer())
+ return sess.run(out_seq)
+
+ def testLSTMLayer(self):
+ # Run with all-0 weights, no padding.
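+    # With all-zero weights each step multiplies the surviving c state by
+    # sigmoid(0) = 0.5 and emits m = 0.5 * c, so three unrolled steps starting
+    # from c_init = 1 produce outputs .25, .125 and .0625, while c_init = 0
+    # keeps every output at zero.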
+ o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 0., 0.)
+ self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
+ o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 0., 1., 0.)
+ self.assertAllClose(o, [[[.25]] * self._batch_size,
+ [[.125]] * self._batch_size,
+ [[.0625]] * self._batch_size])
+ o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 0., 0.)
+ self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
+ o = self._RunLSTMLayer('zeros', init_ops.zeros_initializer(), 1., 1., 0.)
+ self.assertAllClose(o, [[[.25]] * self._batch_size,
+ [[.125]] * self._batch_size,
+ [[.0625]] * self._batch_size])
+
+ # Run with all-1 weights, no padding.
+ weight1 = 1.
+ for m_init in [0., 1.]:
+ for c_init in [0., 1.]:
+ o = self._RunLSTMLayer('ones',
+ init_ops.ones_initializer(), m_init, c_init, 0.)
+ m0 = self._NextM(self._inputs, weight1, m_init, c_init)
+ c0 = self._NextC(self._inputs, weight1, m_init, c_init)
+ self.assertAllClose(o[0], m0)
+ m1 = self._NextM(self._inputs, weight1, m0, c0)
+ c1 = self._NextC(self._inputs, weight1, m0, c0)
+ self.assertAllClose(o[1], m1)
+ m2 = self._NextM(self._inputs, weight1, m1, c1)
+ self.assertAllClose(o[2], m2)
+
+ # Run with random weights.
+ for weight in np.random.rand(3):
+ weight_tf = constant_op.constant(weight, dtypes.float32)
+ random_weight = lambda shape, w=weight_tf: array_ops.fill(shape, w)
+
+ # No padding.
+ for m_init in [0., 1.]:
+ for c_init in [0., 1.]:
+ o = self._RunLSTMLayer('random', random_weight, m_init, c_init, 0.)
+ m0 = self._NextM(self._inputs, weight, m_init, c_init)
+ c0 = self._NextC(self._inputs, weight, m_init, c_init)
+ self.assertAllClose(o[0], m0)
+ m1 = self._NextM(self._inputs, weight, m0, c0)
+ c1 = self._NextC(self._inputs, weight, m0, c0)
+ self.assertAllClose(o[1], m1)
+ m2 = self._NextM(self._inputs, weight, m1, c1)
+ self.assertAllClose(o[2], m2)
+
+ # Set padding.
+ o = self._RunLSTMLayer('random', random_weight, 0., 0., 1.)
+ self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
+ o = self._RunLSTMLayer('random', random_weight, 0., 1., 1.)
+ self.assertAllClose(o, [[[0.]] * self._batch_size] * 3)
+ o = self._RunLSTMLayer('random', random_weight, 1., 0., 1.)
+ self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
+ o = self._RunLSTMLayer('random', random_weight, 1., 1., 1.)
+ self.assertAllClose(o, [[[1.]] * self._batch_size] * 3)
+
+
+class LSTMBenchmark(test.Benchmark):
+  """Micro-benchmarks for a single layer of LSTM cells."""
+
+ def _LayerBuilder(self, do_training):
+ out_seq, weights = lstm.BuildLSTMLayer(FLAGS.batch_size, FLAGS.seq_length,
+ FLAGS.num_inputs, FLAGS.num_nodes)
+ name, fetches = ('lstm_layer_inference', out_seq)
+ if do_training:
+ # Not a real loss function, but good enough for benchmarking backprop.
+ loss = math_ops.reduce_sum(math_ops.add_n(out_seq))
+ dw = gradients_impl.gradients(loss, weights)
+ name, fetches = ('lstm_layer_training', dw)
+
+ _DumpGraph(ops.get_default_graph(),
+ '%s_%d_%d_%d_%d' % (name, FLAGS.batch_size, FLAGS.seq_length,
+ FLAGS.num_inputs, FLAGS.num_nodes))
+ return name, fetches
+
+ def benchmarkLayerInference(self):
+ xla_test.Benchmark(self, lambda: self._LayerBuilder(False), False,
+ FLAGS.device)
+
+ def benchmarkLayerInferenceXLA(self):
+ xla_test.Benchmark(self, lambda: self._LayerBuilder(False), True,
+ FLAGS.device)
+
+ def benchmarkLayerTraining(self):
+ xla_test.Benchmark(self, lambda: self._LayerBuilder(True), False,
+ FLAGS.device)
+
+ def benchmarkLayerTrainingXLA(self):
+ xla_test.Benchmark(self, lambda: self._LayerBuilder(True), True,
+ FLAGS.device)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py
new file mode 100644
index 0000000000..566c02e72f
--- /dev/null
+++ b/tensorflow/compiler/tests/nary_ops_test.py
@@ -0,0 +1,209 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for operators with > 3 or arbitrary numbers of arguments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class NAryOpsTest(XLATestCase):
+
+ def _testNAry(self, op, args, expected):
+ with self.test_session() as session:
+ with self.test_scope():
+ placeholders = [
+ array_ops.placeholder(dtypes.as_dtype(arg.dtype), arg.shape)
+ for arg in args
+ ]
+ feeds = {placeholders[i]: args[i] for i in range(0, len(args))}
+ output = op(placeholders)
+ result = session.run(output, feeds)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testFloat(self):
+ self._testNAry(math_ops.add_n,
+ [np.array([[1, 2, 3]], dtype=np.float32)],
+ expected=np.array([[1, 2, 3]], dtype=np.float32))
+
+ self._testNAry(math_ops.add_n,
+ [np.array([1, 2], dtype=np.float32),
+ np.array([10, 20], dtype=np.float32)],
+ expected=np.array([11, 22], dtype=np.float32))
+ self._testNAry(math_ops.add_n,
+ [np.array([-4], dtype=np.float32),
+ np.array([10], dtype=np.float32),
+ np.array([42], dtype=np.float32)],
+ expected=np.array([48], dtype=np.float32))
+
+ def testConcat(self):
+ self._testNAry(
+ lambda x: array_ops.concat_v2(x, 0), [
+ np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
+ [[7, 8, 9], [10, 11, 12]], dtype=np.float32)
+ ],
+ expected=np.array(
+ [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.float32))
+
+ self._testNAry(
+ lambda x: array_ops.concat_v2(x, 1), [
+ np.array(
+ [[1, 2, 3], [4, 5, 6]], dtype=np.float32), np.array(
+ [[7, 8, 9], [10, 11, 12]], dtype=np.float32)
+ ],
+ expected=np.array(
+ [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], dtype=np.float32))
+
+ def testSplitV(self):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = session.run(
+ array_ops.split(np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2]],
+ dtype=np.float32),
+ [2, 2], 1))
+ expected = [np.array([[1, 2], [5, 6], [9, 0]], dtype=np.float32),
+ np.array([[3, 4], [7, 8], [1, 2]], dtype=np.float32)]
+ self.assertAllEqual(output, expected)
+
+ def testStridedSlice(self):
+ self._testNAry(lambda x: array_ops.strided_slice(*x),
+ [np.array([[], [], []], dtype=np.float32),
+ np.array([1, 0], dtype=np.int32),
+ np.array([3, 0], dtype=np.int32),
+ np.array([1, 1], dtype=np.int32)],
+ expected=np.array([[], []], dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice(*x),
+ [np.array([[], [], []], dtype=np.float32),
+ np.array([1, 0], dtype=np.int64),
+ np.array([3, 0], dtype=np.int64),
+ np.array([1, 1], dtype=np.int64)],
+ expected=np.array([[], []], dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice(*x),
+ [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ dtype=np.float32),
+ np.array([1, 1], dtype=np.int32),
+ np.array([3, 3], dtype=np.int32),
+ np.array([1, 1], dtype=np.int32)],
+ expected=np.array([[5, 6], [8, 9]], dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice(*x),
+ [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ dtype=np.float32),
+ np.array([0, 2], dtype=np.int32),
+ np.array([2, 0], dtype=np.int32),
+ np.array([1, -1], dtype=np.int32)],
+ expected=np.array([[3, 2], [6, 5]], dtype=np.float32))
+
+ self._testNAry(lambda x: x[0][0:2, array_ops.newaxis, ::-1],
+ [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ dtype=np.float32)],
+ expected=np.array([[[3, 2, 1]], [[6, 5, 4]]],
+ dtype=np.float32))
+
+ self._testNAry(lambda x: x[0][1, :, array_ops.newaxis],
+ [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
+ dtype=np.float32)],
+ expected=np.array([[4], [5], [6]], dtype=np.float32))
+
+ def testStridedSliceGrad(self):
+ # Tests cases where input shape is empty.
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([], dtype=np.int32),
+ np.array([], dtype=np.int32),
+ np.array([], dtype=np.int32),
+ np.array([], dtype=np.int32),
+ np.float32(0.5)],
+ expected=np.array(np.float32(0.5), dtype=np.float32))
+
+ # Tests case where input shape is non-empty, but gradients are empty.
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([3], dtype=np.int32),
+ np.array([0], dtype=np.int32),
+ np.array([0], dtype=np.int32),
+ np.array([1], dtype=np.int32),
+ np.array([], dtype=np.float32)],
+ expected=np.array([0, 0, 0], dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([3, 0], dtype=np.int32),
+ np.array([1, 0], dtype=np.int32),
+ np.array([3, 0], dtype=np.int32),
+ np.array([1, 1], dtype=np.int32),
+ np.array([[], []], dtype=np.float32)],
+ expected=np.array([[], [], []], dtype=np.float32))
+
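+ # The gradient [[5, 6], [8, 9]] of the [1:3, 1:3] slice of a 3x3 input is
+ # scattered back to the positions it was read from; all other entries are 0.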
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([3, 3], dtype=np.int32),
+ np.array([1, 1], dtype=np.int32),
+ np.array([3, 3], dtype=np.int32),
+ np.array([1, 1], dtype=np.int32),
+ np.array([[5, 6], [8, 9]], dtype=np.float32)],
+ expected=np.array([[0, 0, 0], [0, 5, 6], [0, 8, 9]],
+ dtype=np.float32))
+
+ def ssg_test(x):
+ return array_ops.strided_slice_grad(*x, shrink_axis_mask=0x4,
+ new_axis_mask=0x1)
+
+ self._testNAry(ssg_test,
+ [np.array([3, 1, 3], dtype=np.int32),
+ np.array([0, 0, 0, 2], dtype=np.int32),
+ np.array([0, 3, 1, -4], dtype=np.int32),
+ np.array([1, 2, 1, -3], dtype=np.int32),
+ np.array([[[1], [2]]], dtype=np.float32)],
+ expected=np.array([[[0, 0, 1]], [[0, 0, 0]], [[0, 0, 2]]],
+ dtype=np.float32))
+
+ ssg_test2 = lambda x: array_ops.strided_slice_grad(*x, new_axis_mask=0x15)
+ self._testNAry(ssg_test2,
+ [np.array([4, 4], dtype=np.int32),
+ np.array([0, 0, 0, 1, 0], dtype=np.int32),
+ np.array([0, 3, 0, 4, 0], dtype=np.int32),
+ np.array([1, 2, 1, 2, 1], dtype=np.int32),
+ np.array([[[[[1], [2]]], [[[3], [4]]]]], dtype=np.float32)],
+ expected=np.array([[0, 1, 0, 2], [0, 0, 0, 0], [0, 3, 0, 4],
+ [0, 0, 0, 0]], dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([3, 3], dtype=np.int32),
+ np.array([0, 2], dtype=np.int32),
+ np.array([2, 0], dtype=np.int32),
+ np.array([1, -1], dtype=np.int32),
+ np.array([[1, 2], [3, 4]], dtype=np.float32)],
+ expected=np.array([[0, 2, 1], [0, 4, 3], [0, 0, 0]],
+ dtype=np.float32))
+
+ self._testNAry(lambda x: array_ops.strided_slice_grad(*x),
+ [np.array([3, 3], dtype=np.int32),
+ np.array([2, 2], dtype=np.int32),
+ np.array([0, 1], dtype=np.int32),
+ np.array([-1, -2], dtype=np.int32),
+ np.array([[1], [2]], dtype=np.float32)],
+ expected=np.array([[0, 0, 0], [0, 0, 2], [0, 0, 1]],
+ dtype=np.float32))
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
new file mode 100644
index 0000000000..6f588d8ab5
--- /dev/null
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -0,0 +1,61 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for operators with no arguments."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import googletest
+
+
+class NullaryOpsTest(XLATestCase):
+
+ def _testNullary(self, op, expected):
+ with self.test_session() as session:
+ with self.test_scope():
+ output = op()
+ result = session.run(output)
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testNoOp(self):
+ with self.test_session():
+ with self.test_scope():
+ output = control_flow_ops.no_op()
+ # This should not crash.
+ output.run()
+
+ def testConstants(self):
+ constants = [
+ np.float32(42),
+ np.array([], dtype=np.float32),
+ np.array([1, 2], dtype=np.float32),
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
+ np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
+ dtype=np.float32),
+ np.array([[[]], [[]]], dtype=np.float32),
+ np.array([[[[1]]]], dtype=np.float32),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/pooling_ops_test.py b/tensorflow/compiler/tests/pooling_ops_test.py
new file mode 100644
index 0000000000..52290e6354
--- /dev/null
+++ b/tensorflow/compiler/tests/pooling_ops_test.py
@@ -0,0 +1,511 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for pooling operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import googletest
+
+
+def NHWCToNCHW(input_tensor):
+ """Convert the input from NHWC format to NCHW.
+
+ Args:
+ input_tensor: a 4-D tensor, or a 4-element array representing the same.
+
+ Returns:
+ the converted tensor or a shape array
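+
+ For example, applied to a shape array, NHWCToNCHW([10, 5, 7, 3]) returns
+ [10, 3, 5, 7].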
+ """
+ if isinstance(input_tensor, ops.Tensor):
+ return array_ops.transpose(input_tensor, [0, 3, 1, 2])
+ else:
+ return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]]
+
+
+def NCHWToNHWC(input_tensor):
+ """Convert the input from NCHW format to NHWC.
+
+ Args:
+ input_tensor: a 4-D tensor, or a 4-element array representing the same.
+
+ Returns:
+ the converted tensor or a shape array
+ """
+ if isinstance(input_tensor, ops.Tensor):
+ return array_ops.transpose(input_tensor, [0, 2, 3, 1])
+ else:
+ return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]]
+
+
+def GetTestConfigs():
+ """Get all the valid tests configs to run.
+
+ Returns:
+ all the valid test configs
+ """
+ test_configs = ["NHWC", "NCHW"]
+ return test_configs
+
+
+class PoolingTest(XLATestCase):
+
+ def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
+ data_format, expected):
+ """Verifies the output values of the pooling function.
+
+ Args:
+ pool_func: Function to be called, e.g., nn_ops.max_pool.
+ input_sizes: Input tensor dimensions.
+ ksize: The kernel size dimensions
+ strides: The stride dimensions
+ padding: Padding type.
+ data_format: The data format we use to run the pooling operation.
+ expected: An array containing the expected operation outputs.
+ """
+ total_size = np.prod(input_sizes)
+ # Initializes the input tensor with an array containing incrementing
+ # numbers from 1.
+ x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32)
+ x = x.reshape(input_sizes)
+ with self.test_session() as sess:
+ with self.test_scope():
+ inputs = array_ops.placeholder(dtypes.float32)
+ t = inputs
+ if data_format == "NCHW":
+ t = NHWCToNCHW(t)
+ ksize = NHWCToNCHW(ksize)
+ strides = NHWCToNCHW(strides)
+ t = pool_func(t,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
+ if data_format == "NCHW":
+ t = NCHWToNHWC(t)
+ actual = sess.run(t, {inputs: x})
+ self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6)
+
+ def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
+ expected):
+ """Verifies the output values of the pooling function.
+
+ Args:
+ pool_func: Function to be called, e.g., nn_ops.max_pool or
+ nn_ops.avg_pool.
+ input_sizes: Input tensor dimensions.
+ ksize: The kernel size dimensions
+ strides: The stride dimensions
+ padding: Padding type.
+ expected: An array containing the expected operation outputs.
+ """
+ for data_format in GetTestConfigs():
+ self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
+ data_format, expected)
+
+ def testMaxPoolValidPadding(self):
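+ # The input is 1..27 reshaped to [1, 3, 3, 3]. A 2x2 VALID window with
+ # stride 2 yields a single spatial output per channel: the maximum of the
+ # top-left 2x2 patch, i.e. the element at (h=1, w=1), giving 13, 14, 15.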
+ expected_output = [13.0, 14.0, 15.0]
+ self._VerifyValues(nn_ops.max_pool,
+ input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="VALID",
+ expected=expected_output)
+
+ def testMaxPoolSamePadding(self):
+ expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
+ self._VerifyValues(nn_ops.max_pool,
+ input_sizes=[1, 2, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output)
+
+ def testMaxPoolSamePaddingNonSquareWindow(self):
+ # input is:
+ # [1.0, 2.0
+ # 3.0, 4.0]
+ #
+ # A 1x2 max-pooling window with SAME padding should produce:
+ #
+ # [max(1.0, 2.0), max(2.0, padded0),
+ # max(3.0, 4.0), max(4.0, padded0)]
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 2, 2, 1],
+ ksize=[1, 1, 2, 1],
+ strides=[1, 1, 1, 1],
+ padding="SAME",
+ expected=[2.0, 2.0, 4.0, 4.0])
+
+ def testMaxPoolValidPaddingUnevenStride(self):
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 1, 2, 1],
+ padding="VALID",
+ expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0])
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 1, 1],
+ padding="VALID",
+ expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0])
+
+ def testMaxPoolSamePaddingFilter4(self):
+ expected_output = [
+ 21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
+ 61.0, 62.0, 63.0, 64.0
+ ]
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 4, 4, 4],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output)
+
+ def testMaxPoolSamePaddingFilter8(self):
+ expected_output = [
+ 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
+ 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
+ 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
+ 191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
+ 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
+ 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
+ 317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
+ 407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
+ 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
+ 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
+ 469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
+ 487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
+ 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
+ ]
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 8, 8, 8],
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output)
+
+ # Tests for DepthwiseMaxPooling on CPU only.
+ def testDepthwiseMaxPool1x1DepthWindow1(self):
+ # input is:
+ # [1.0, ..., 10.0] along depth,
+ #
+ # We maxpool by depth in patches of 2.
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 1, 1, 10],
+ ksize=[1, 1, 1, 2],
+ strides=[1, 1, 1, 2],
+ padding="SAME",
+ expected=[2.0, 4.0, 6.0, 8.0, 10.0])
+
+ def testDepthwiseMaxPool2x2DepthWindow3(self):
+ # input is:
+ #
+ # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
+ # output. Each node has contiguous values, so the depthwise max
+ # should be multiples of 3.0.
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 2, 2, 6],
+ ksize=[1, 1, 1, 3],
+ strides=[1, 1, 1, 3],
+ padding="SAME",
+ expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0])
+
+ def testKernelSmallerThanStrideValid(self):
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 7, 7, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 3, 3, 1],
+ padding="VALID",
+ expected=[9, 12, 30, 33])
+
+ def testKernelSmallerThanStrideSame(self):
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 3, 3, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 7, 9])
+
+ self._VerifyValues(
+ nn_ops.max_pool,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 9, 11])
+
+ # Average pooling
+ def testAvgPoolValidPadding(self):
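+ # The input is 1..27 reshaped to [1, 3, 3, 3]; averaging the top-left 2x2
+ # patch of channel 0 gives (1 + 4 + 10 + 13) / 4 = 7, hence 7, 8, 9 across
+ # the three channels.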
+ expected_output = [7, 8, 9]
+ self._VerifyValues(
+ nn_ops.avg_pool,
+ input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="VALID",
+ expected=expected_output)
+
+ def testAvgPoolSamePadding(self):
+ expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
+ self._VerifyValues(
+ nn_ops.avg_pool,
+ input_sizes=[1, 2, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=expected_output)
+
+
+class PoolGradTest(XLATestCase):
+
+ CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
+
+ def _VerifyOneTest(self, pool_func, pool_grad_func, input_sizes, ksize,
+ strides, padding, data_format):
+ """Verifies the output values of the pooling gradient function.
+
+ Args:
+ pool_func: Forward pooling function
+ pool_grad_func: Pooling gradient function corresponding to pool_func.
+ input_sizes: Input tensor dimensions.
+ ksize: The kernel size dimensions
+ strides: The stride dimensions
+ padding: Padding type.
+ data_format: The data format we use to run the pooling operation.
+ """
+ total_size = np.prod(input_sizes)
+ x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
+ with self.test_session() as sess:
+ # Use the forward pool function to compute some corresponding outputs
+ # (needed for the CPU device, and we need the shape in both cases).
+ with ops.device(self.CPU_DEVICE):
+ inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
+ outputs = pool_func(
+ inputs,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format="NHWC")
+
+ output_vals = np.array(sess.run(outputs, {inputs: x}))
+ output_gradient_vals = np.arange(
+ 1, output_vals.size + 1, dtype=np.float32)
+ output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
+
+ # Use the Tensorflow CPU pooling gradient to compute the expected input
+ # gradients.
+ with ops.device(self.CPU_DEVICE):
+ output_gradients = array_ops.placeholder(
+ dtypes.float32, shape=output_vals.shape)
+ expected_input_gradients = pool_grad_func(
+ inputs,
+ outputs,
+ output_gradients,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format="NHWC")
+ expected_input_gradient_vals = sess.run(
+ expected_input_gradients,
+ {inputs: x,
+ output_gradients: output_gradient_vals})
+
+ # Run the gradient op on the XLA device
+ with self.test_scope():
+ outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
+ xla_inputs = inputs
+ xla_outputs = outputs
+ xla_output_gradients = output_gradients
+ xla_ksize = ksize
+ xla_strides = strides
+ if data_format == "NCHW":
+ xla_inputs = NHWCToNCHW(inputs)
+ xla_outputs = NHWCToNCHW(outputs)
+ xla_output_gradients = NHWCToNCHW(output_gradients)
+ xla_ksize = NHWCToNCHW(ksize)
+ xla_strides = NHWCToNCHW(strides)
+ actual_input_gradients = pool_grad_func(
+ xla_inputs,
+ xla_outputs,
+ xla_output_gradients,
+ ksize=xla_ksize,
+ strides=xla_strides,
+ padding=padding,
+ data_format=data_format)
+ if data_format == "NCHW":
+ actual_input_gradients = NCHWToNHWC(actual_input_gradients)
+ actual = sess.run(actual_input_gradients, {
+ inputs: x,
+ outputs: output_vals,
+ output_gradients: output_gradient_vals
+ })
+
+ # Compare the Tensorflow and XLA results.
+ self.assertAllClose(
+ expected_input_gradient_vals.flatten(),
+ actual.flatten(),
+ rtol=1e-5,
+ atol=1e-6)
+ self.assertShapeEqual(actual, inputs)
+
+ def _VerifyValues(self, pool_func, pool_grad_func, input_sizes, ksize,
+ strides, padding):
+ """Verifies the output values of the pooling function.
+
+ Args:
+ pool_func: Pooling function to be called, e.g., tf.nn.max_pool
+ pool_grad_func: Corresponding pooling gradient function.
+ input_sizes: Input tensor dimensions.
+ ksize: The kernel size dimensions
+ strides: The stride dimensions
+ padding: Padding type.
+ """
+ for data_format in GetTestConfigs():
+ self._VerifyOneTest(pool_func, pool_grad_func, input_sizes, ksize,
+ strides, padding, data_format)
+
+ def _TestPooling(self, forward_op, backward_op):
+ # VALID padding
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 3, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="VALID")
+
+ # SAME padding
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 2, 3, 3],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME")
+
+ # SAME padding, non square window
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 2, 2, 1],
+ ksize=[1, 1, 2, 1],
+ strides=[1, 1, 1, 1],
+ padding="SAME")
+
+ # VALID padding, uneven stride
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 1, 2, 1],
+ padding="VALID")
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 1, 1],
+ padding="VALID")
+
+ # SAME padding, size 4 input
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 4, 4, 4],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME")
+
+ # SAME padding, size 8 input
+ self._VerifyValues(
+ forward_op,
+ backward_op,
+ input_sizes=[1, 8, 8, 8],
+ ksize=[1, 3, 3, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME")
+
+ def testMaxPool(self):
+ self._TestPooling(nn_ops.max_pool, gen_nn_ops._max_pool_grad)
+
+ def testAvgPool(self):
+ # Wrapper around AvgPoolGrad that ignores extra arguments needed by
+ # MaxPoolGrad.
+ def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
+ data_format):
+ del outputs # Unused by average-pooling gradients.
+ return gen_nn_ops._avg_pool_grad(
+ inputs.get_shape().as_list(),
+ output_gradients,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format)
+
+ self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)
+
+ # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
+ # the stride size, so we only run the following tests on MaxPoolGrad.
+
+ def testMaxPoolKernelSmallerThanStrideValid(self):
+ self._VerifyValues(
+ nn_ops.max_pool,
+ gen_nn_ops._max_pool_grad,
+ input_sizes=[1, 7, 7, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 3, 3, 1],
+ padding="VALID")
+
+ def testMaxPoolKernelSmallerThanStrideSame(self):
+ self._VerifyValues(
+ nn_ops.max_pool,
+ gen_nn_ops._max_pool_grad,
+ input_sizes=[1, 3, 3, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME")
+
+ self._VerifyValues(
+ nn_ops.max_pool,
+ gen_nn_ops._max_pool_grad,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME")
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
new file mode 100644
index 0000000000..41403858a6
--- /dev/null
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -0,0 +1,2097 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Randomized tests for XLA implementations of Tensorflow operations.
+//
+// For each operator, the tests in this file choose a set of random inputs and
+// attributes. The test then compares the outputs of the operator when executed
+// via Tensorflow using the CPU device and when executed via XLA.
+//
+// By default, each test chooses a random seed nondeterministically (using
+// std::random_device). However, a particular choice of random seed can be
+// forced using the flag --tf_xla_random_seed; each test logs the
+// flag value necessary to reproduce its outputs.
+//
+// Example usage:
+// Run tests, comparing the Tensorflow CPU operators with their XLA-compiled
+// counterparts:
+// randomized_tests \
+// --tf_xla_test_use_jit=true --tf_xla_test_device=CPU \
+// --tf_xla_test_repetitions=20
+
+// TODO(phawkins): add tests for:
+// * ArgMax
+// * DepthwiseConv2DNative
+// * Gather
+// * InvertPermutation
+// * MaxPoolGrad (requires implementation of forward operator)
+// * Select
+// * Unpack
+//
+// TODO(phawkins): improve tests for:
+// * StridedSliceGrad (need to use shape function to compute sensible inputs)
+
+#include <random>
+#include <unordered_map>
+
+#include "tensorflow/compiler/jit/defs.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace {
+
+// Command line flags: see main() below.
+int64 tf_xla_random_seed = 0;
+int32 tf_xla_test_repetitions = 20;
+string* tf_xla_test_device_ptr; // initial value set in main()
+bool tf_xla_test_use_jit = true;
+
+string DeviceTypeToDeviceName(DeviceType type) {
+ return strings::StrCat("/job:localhost/replica:0/task:0/device:", type.type(),
+ ":0");
+}
+
+constexpr std::array<DataType, 3> kAllXlaTypes = {
+ {DT_INT32, DT_FLOAT, DT_BOOL}};
+
+// An OpTestBuilder is a graph builder class that takes as input an operator to
+// test, its inputs and attributes, and builds a graph that executes the
+// operator.
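+// For example, the Add test below builds
+//   OpTestBuilder("Add").Input(...).Input(...).Attr("T", type)
+// and hands it to OpTest::ExpectTfAndXlaOutputsAreClose, which runs the
+// resulting graph with and without XLA and checks that the outputs are close.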
+class OpTestBuilder {
+ public:
+ explicit OpTestBuilder(const string& op_name);
+
+ // Adds an input 'tensor'.
+ OpTestBuilder& Input(Tensor tensor);
+
+ // Sets an attribute.
+ template <class T>
+ OpTestBuilder& Attr(StringPiece attr_name, T&& value);
+
+ // Overload needed to allow {...} expressions for value.
+ template <class T>
+ OpTestBuilder& Attr(StringPiece attr_name, std::initializer_list<T> value);
+
+ // Adds nodes that execute the operator under test on 'device' to 'graphdef'.
+ // If 'use_jit' is true, marks the operator under test to be compiled by XLA.
+ // The graph will consist of one Placeholder node per input, the operator
+ // itself, and one Identity node per output. If 'test_node_def' is not null,
+ // sets it to the NodeDef of the operator under test. Fills 'inputs' and
+ // 'outputs' with the names of the input placeholder nodes and the output
+ // identity nodes, respectively.
+ Status BuildGraph(string name_prefix, string device, bool use_jit,
+ GraphDef* graphdef, NodeDef** test_node_def,
+ std::vector<string>* inputs,
+ std::vector<string>* outputs) const;
+
+ const std::vector<Tensor>& inputs() const { return inputs_; }
+
+ private:
+ NodeDef node_def_;
+ std::vector<Tensor> inputs_;
+};
+
+OpTestBuilder::OpTestBuilder(const string& op_name) {
+ node_def_.set_op(op_name);
+}
+
+OpTestBuilder& OpTestBuilder::Input(Tensor tensor) {
+ VLOG(1) << "Adding input: " << tensor.DebugString();
+ inputs_.push_back(tensor);
+ return *this;
+}
+
+template <class T>
+OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name, T&& value) {
+ AddNodeAttr(attr_name, std::forward<T>(value), &node_def_);
+ return *this;
+}
+
+template <class T>
+OpTestBuilder& OpTestBuilder::Attr(StringPiece attr_name,
+ std::initializer_list<T> value) {
+ Attr<std::initializer_list<T>>(attr_name, std::move(value));
+ return *this;
+}
+
+Status OpTestBuilder::BuildGraph(string name_prefix, string device,
+ bool use_jit, GraphDef* graphdef,
+ NodeDef** test_node_def,
+ std::vector<string>* inputs,
+ std::vector<string>* outputs) const {
+ OpRegistryInterface* op_registry = OpRegistry::Global();
+
+ const OpDef* op_def;
+ TF_RETURN_IF_ERROR(op_registry->LookUpOpDef(node_def_.op(), &op_def));
+
+ NodeDef* test_def = graphdef->add_node();
+ *test_def = node_def_;
+ test_def->set_name(strings::StrCat(name_prefix, "_op_under_test"));
+ test_def->set_device(device);
+ AddDefaultsToNodeDef(*op_def, test_def);
+ if (use_jit) {
+ AddNodeAttr(kXlaCompileAttr, true, test_def);
+ }
+ VLOG(1) << "Op under test: " << test_def->DebugString();
+
+ DataTypeVector input_types, output_types;
+ TF_RETURN_IF_ERROR(
+ InOutTypesForNode(*test_def, *op_def, &input_types, &output_types));
+
+ // Build feed and fetch nodes.
+ for (int i = 0; i < input_types.size(); ++i) {
+ NodeDef* def = graphdef->add_node();
+ string name = strings::StrCat(name_prefix, "_input_", i);
+ TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Placeholder")
+ .Device(device)
+ .Attr("dtype", input_types[i])
+ .Finalize(def));
+ inputs->push_back(name);
+ test_def->add_input(name);
+ }
+
+ for (int i = 0; i < output_types.size(); ++i) {
+ NodeDef* def = graphdef->add_node();
+ string name = strings::StrCat(name_prefix, "_output_", i);
+ TF_RETURN_IF_ERROR(NodeDefBuilder(name, "Identity")
+ .Device(device)
+ .Attr("T", output_types[i])
+ .Input(test_def->name(), i, output_types[i])
+ .Finalize(def));
+ outputs->push_back(name);
+ }
+
+ if (test_node_def) {
+ *test_node_def = test_def;
+ }
+
+ return Status::OK();
+}
+
+// Test fixture. The fixture manages the random number generator and its seed,
+// and has a number of convenience methods for building random Tensors, shapes,
+// etc.
+class OpTest : public ::testing::Test {
+ public:
+ OpTest();
+
+ // Runs 'fn' up to --tf_xla_test_repetitions times or until a failure
+ // occurs, whichever happens first.
+ void Repeatedly(std::function<void(void)> fn);
+
+ // Select a random element from 'candidates'.
+ template <typename T>
+ T Choose(gtl::ArraySlice<T> candidates);
+
+ static constexpr int kDefaultMaxRank = 5;
+ static constexpr int64 kDefaultMaxDimensionSize = 20LL;
+
+ // Returns a random dimension size in the range [min, max).
+ int64 RandomDim(int64 min = 0, int64 max = kDefaultMaxDimensionSize);
+
+ // Returns a random shape. The tensor has rank in the range
+ // [min_rank, max_rank]. Each dimension has size in the range
+ // [min_size, max_size).
+ std::vector<int64> RandomDims(int min_rank = 0,
+ int max_rank = kDefaultMaxRank,
+ int64 min_size = 0,
+ int64 max_size = kDefaultMaxDimensionSize);
+
+ // Given a shape 'dims', build a pair of dimensions such that one broadcasts
+ // to the other.
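+ // For example, given dims = {3, 4, 5} this might return
+ // ({3, 4, 5}, {4, 1}), since {4, 1} broadcasts against the trailing
+ // dimensions of {3, 4, 5}.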
+ std::pair<std::vector<int64>, std::vector<int64>> BroadcastableDims(
+ std::vector<int64> dims);
+
+ // Builds a random pair of broadcastable dims.
+ // TODO(phawkins): currently the maximum rank is 3, because broadcasting > 3
+ // dimensions is unimplemented by the Tensorflow Eigen code (b/29268487)
+ std::pair<std::vector<int64>, std::vector<int64>> BroadcastableDims();
+
+ // Returns a tensor filled with random but "reasonable" values from the middle
+ // of the type's range. If the shape is omitted, a random shape is used.
+ // TODO(phawkins): generalize this code to a caller-supplied distribution.
+ Tensor RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+ Tensor RandomTensor(DataType dtype);
+
+ // Like RandomTensor, but uses values >= 0.
+ Tensor RandomNonNegativeTensor(DataType dtype, gtl::ArraySlice<int64> shape);
+ Tensor RandomNonNegativeTensor(DataType dtype);
+
+ // Returns a random subset of the integers in the range [0, rank), suitable
+ // for use as reduction indices.
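+ // For example, for rank 4 this might return [0, 2], or an empty tensor if
+ // no index is chosen.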
+ Tensor RandomReductionIndices(int rank);
+
+ struct WindowedDims {
+ Padding padding;
+ int kernel_rows, kernel_cols;
+ int stride_rows, stride_cols;
+ int input_rows, input_cols;
+ int64 output_rows, output_cols;
+ };
+ // Choose dimensions for a 2D windowed op such as pooling or convolution.
+ // TODO(phawkins): currently this only produces spatial windows, in NHWC
+ // format.
+ WindowedDims ChooseWindowedDims();
+
+ std::mt19937& generator() { return *generator_; }
+
+ // Run the test case described by 'builder' with and without XLA and check
+ // that the outputs are close. Tensors x and y are close if they have the same
+ // type, same shape, and have close values. For floating-point tensors, the
+ // element-wise difference between x and y must be less than
+ // atol + rtol * abs(x); or both elements may be NaN or infinity. For
+ // non-floating-point tensors the element values must match exactly.
+ void ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
+ double atol = 1e-2, double rtol = 1e-2);
+
+ protected:
+ // Per-test state:
+ std::unique_ptr<std::mt19937> generator_;
+
+ std::unique_ptr<Session> session_;
+
+ // Number of test cases built in 'session_'. Used to uniquify node names.
+ int num_tests_ = 0;
+};
+
+OpTest::OpTest() {
+ // Creates a random-number generator for the test case. Use the value of
+ // --tf_xla_random_seed as the seed, if provided.
+ int64 s = tf_xla_random_seed;
+ unsigned int seed;
+ if (s <= 0) {
+ std::random_device random_device;
+ seed = random_device();
+ } else {
+ seed = static_cast<unsigned int>(s);
+ }
+ LOG(INFO) << "Random seed for test case: " << seed
+ << ". To reproduce the "
+ "results of this test, pass flag --tf_xla_random_seed="
+ << seed;
+ generator_.reset(new std::mt19937(seed));
+
+ // Create a session with an empty graph.
+ SessionOptions session_options;
+ session_.reset(NewSession(session_options));
+ GraphDef def;
+ TF_CHECK_OK(session_->Create(def));
+}
+
+void OpTest::Repeatedly(std::function<void(void)> fn) {
+ int const max_repetitions = tf_xla_test_repetitions;
+ for (int i = 0; !HasFailure() && i < max_repetitions; ++i) {
+ fn();
+ }
+}
+
+template <typename T>
+T OpTest::Choose(gtl::ArraySlice<T> candidates) {
+ std::uniform_int_distribution<size_t> d(0, candidates.size() - 1);
+ return candidates[d(generator())];
+}
+
+int64 OpTest::RandomDim(int64 min, int64 max) {
+ std::uniform_int_distribution<int64> size_distribution(min, max - 1);
+ return size_distribution(generator());
+}
+
+std::vector<int64> OpTest::RandomDims(int min_rank, int max_rank,
+ int64 min_size, int64 max_size) {
+ CHECK_LE(0, min_rank);
+ CHECK_LE(min_rank, max_rank);
+ std::uniform_int_distribution<int> rank_distribution(min_rank, max_rank);
+ int rank = rank_distribution(generator());
+ std::vector<int64> dims(rank);
+ std::generate(dims.begin(), dims.end(), [this, min_size, max_size]() {
+ return RandomDim(min_size, max_size);
+ });
+ return dims;
+}
+
+Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice<int64> shape) {
+ Tensor tensor(dtype, TensorShape(shape));
+ switch (dtype) {
+ case DT_FLOAT: {
+ std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+ test::FillFn<float>(&tensor, [this, &distribution](int i) -> float {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_DOUBLE: {
+ std::uniform_real_distribution<double> distribution(-1.0, 1.0);
+ test::FillFn<double>(&tensor, [this, &distribution](int i) -> double {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_INT32: {
+ std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
+ test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_INT64: {
+ std::uniform_int_distribution<int64> distribution(-(1LL << 40),
+ 1LL << 40);
+ test::FillFn<int64>(&tensor, [this, &distribution](int i) -> int64 {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_BOOL: {
+ std::bernoulli_distribution distribution;
+ test::FillFn<bool>(&tensor, [this, &distribution](int i) -> bool {
+ return distribution(generator());
+ });
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unimplemented type " << dtype << " in RandomTensor";
+ }
+ return tensor;
+}
+
+Tensor OpTest::RandomTensor(DataType dtype) {
+ return RandomTensor(dtype, RandomDims());
+}
+
+Tensor OpTest::RandomNonNegativeTensor(DataType dtype,
+ gtl::ArraySlice<int64> shape) {
+ Tensor tensor(dtype, TensorShape(shape));
+ switch (dtype) {
+ case DT_FLOAT: {
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+ test::FillFn<float>(&tensor, [this, &distribution](int i) -> float {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_DOUBLE: {
+ std::uniform_real_distribution<double> distribution(0.0, 1.0);
+ test::FillFn<double>(&tensor, [this, &distribution](int i) -> double {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_INT32: {
+ std::uniform_int_distribution<int32> distribution(0, 1 << 20);
+ test::FillFn<int32>(&tensor, [this, &distribution](int i) -> int32 {
+ return distribution(generator());
+ });
+ break;
+ }
+ case DT_INT64: {
+ std::uniform_int_distribution<int64> distribution(0, 1LL << 40);
+ test::FillFn<int64>(&tensor, [this, &distribution](int i) -> int64 {
+ return distribution(generator());
+ });
+ break;
+ }
+ default:
+ LOG(FATAL) << "Unimplemented type " << dtype
+ << " in RandomNonNegativeTensor";
+ }
+ return tensor;
+}
+
+Tensor OpTest::RandomNonNegativeTensor(DataType dtype) {
+ return RandomNonNegativeTensor(dtype, RandomDims());
+}
+
+std::pair<std::vector<int64>, std::vector<int64>> OpTest::BroadcastableDims(
+ std::vector<int64> dims) {
+ if (dims.empty()) return {dims, dims};
+
+ // Remove some dimensions from the front of 'dims'.
+ size_t skip =
+ std::uniform_int_distribution<size_t>(0, dims.size() - 1)(generator());
+
+ std::vector<int64> bdims(dims.begin() + skip, dims.end());
+
+ // Randomly replace some of the remaining dimensions of 'dims' with 1.
+ std::bernoulli_distribution random_bool;
+
+ for (int64& dim : bdims) {
+ if (random_bool(generator())) {
+ dim = 1LL;
+ }
+ }
+
+ // Possibly swap the roles of 'dims' and 'bdims'.
+ if (random_bool(generator())) {
+ dims.swap(bdims);
+ }
+ return {dims, bdims};
+}
+
+std::pair<std::vector<int64>, std::vector<int64>> OpTest::BroadcastableDims() {
+ return BroadcastableDims(RandomDims(0, 3));
+}
+
+Tensor OpTest::RandomReductionIndices(int rank) {
+ std::bernoulli_distribution random_bool;
+ std::vector<int32> indices;
+ for (int i = 0; i < rank; ++i) {
+ if (random_bool(generator())) {
+ indices.push_back(i);
+ }
+ }
+ return test::AsTensor<int32>(indices);
+}
+
+OpTest::WindowedDims OpTest::ChooseWindowedDims() {
+ WindowedDims d;
+ d.padding = Choose<Padding>({SAME, VALID});
+ std::uniform_int_distribution<int> random_int(1, 5);
+ Status s;
+ // Repeatedly try different filter/stride sizes until we find a valid
+ // combination.
+ do {
+ // CPU implementations require stride <= kernel size.
+ d.kernel_rows = random_int(generator());
+ d.input_rows = RandomDim(d.kernel_rows);
+ d.stride_rows =
+ std::uniform_int_distribution<int>(1, d.kernel_rows)(generator());
+ int64 pad_dummy;
+ s = GetWindowedOutputSize(d.input_rows, d.kernel_rows, d.stride_rows,
+ d.padding, &d.output_rows, &pad_dummy);
+ } while (!s.ok());
+ do {
+ d.kernel_cols = random_int(generator());
+ d.input_cols = RandomDim(d.kernel_cols);
+ d.stride_cols =
+ std::uniform_int_distribution<int>(1, d.kernel_cols)(generator());
+ int64 pad_dummy;
+ s = GetWindowedOutputSize(d.input_cols, d.kernel_cols, d.stride_cols,
+ d.padding, &d.output_cols, &pad_dummy);
+ } while (!s.ok());
+ return d;
+}
+
+// Functions for comparing tensors.
+
+template <typename T>
+bool IsClose(const T& x, const T& y, double atol, double rtol) {
+ if (std::isnan(x) && std::isnan(y)) return true;
+ if (x == y) return true; // Allow inf == inf.
+ return fabs(x - y) < atol + rtol * fabs(x);
+}
+
+template <typename T>
+Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol,
+ double rtol) {
+ auto Tx = x.flat<T>();
+ auto Ty = y.flat<T>();
+ for (int i = 0; i < Tx.size(); ++i) {
+ if (!IsClose(Tx(i), Ty(i), atol, rtol)) {
+ return errors::InvalidArgument(strings::StrCat(
+ i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i),
+ ". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol,
+ " rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i))));
+ }
+ }
+ return Status::OK();
+}
+
+template <typename T>
+Status TensorsAreEqualImpl(const Tensor& x, const Tensor& y) {
+ auto Tx = x.flat<T>();
+ auto Ty = y.flat<T>();
+ for (int i = 0; i < Tx.size(); ++i) {
+ if (Tx(i) != Ty(i)) {
+ return errors::InvalidArgument(strings::StrCat(
+ i, "-th tensor element isn't equal: ", Tx(i), " vs. ", Ty(i),
+ ". x = ", x.DebugString(), "y = ", y.DebugString()));
+ }
+ }
+ return Status::OK();
+}
+
+// Tests if "x" and "y" are tensors of the same type, same shape, and with
+// close values. For floating-point tensors, the element-wise difference between
+// x and y must be less than atol + rtol * abs(x). For non-floating-point
+// tensors the values must match exactly.
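+// For example, with atol = 1e-2 and rtol = 1e-2 (the defaults used by
+// ExpectTfAndXlaOutputsAreClose), x = 1.0 and y = 1.015 are considered close
+// because |x - y| = 0.015 < 0.01 + 0.01 * 1.0 = 0.02.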
+Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol,
+ double rtol) {
+ if (a.dtype() != b.dtype()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Tensors have different types: ", DataTypeString(a.dtype()), " and ",
+ DataTypeString(b.dtype())));
+ }
+ if (!a.IsSameSize(b)) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Tensors have different shapes: ", a.shape().DebugString(), " and ",
+ b.shape().DebugString()));
+ }
+
+ switch (a.dtype()) {
+ case DT_FLOAT:
+ return TensorsAreCloseImpl<float>(a, b, atol, rtol);
+ case DT_DOUBLE:
+ return TensorsAreCloseImpl<double>(a, b, atol, rtol);
+ case DT_INT32:
+ return TensorsAreEqualImpl<int32>(a, b);
+ case DT_INT64:
+ return TensorsAreEqualImpl<int64>(a, b);
+ case DT_BOOL:
+ return TensorsAreEqualImpl<bool>(a, b);
+ default:
+ LOG(FATAL) << "Unexpected type : " << DataTypeString(a.dtype());
+ }
+}
+
+void OpTest::ExpectTfAndXlaOutputsAreClose(const OpTestBuilder& builder,
+ double atol, double rtol) {
+ string cpu_device = DeviceTypeToDeviceName(DEVICE_CPU);
+ DeviceType test_device_type(*tf_xla_test_device_ptr);
+ string test_device = DeviceTypeToDeviceName(test_device_type);
+ ++num_tests_;
+
+ GraphDef graph;
+ std::vector<string> expected_inputs, test_inputs;
+ std::vector<string> expected_fetches, test_fetches;
+ TF_ASSERT_OK(builder.BuildGraph(
+ strings::StrCat("test", num_tests_, "_expected"), cpu_device,
+ /* use_jit= */ false, &graph, /* test_node_def= */ nullptr,
+ &expected_inputs, &expected_fetches));
+
+ NodeDef* node_def;
+ TF_ASSERT_OK(builder.BuildGraph(strings::StrCat("test", num_tests_, "_test"),
+ test_device, tf_xla_test_use_jit, &graph,
+ &node_def, &test_inputs, &test_fetches));
+
+ // Check that there's a kernel corresponding to 'node_def' on the device under
+ // test.
+ Status status = FindKernelDef(test_device_type, *node_def, nullptr, nullptr);
+ if (!status.ok()) {
+ VLOG(1) << "Skipping test because there is no corresponding registered "
+ << "kernel on the test device: " << status;
+ return;
+ }
+
+ TF_ASSERT_OK(session_->Extend(graph));
+
+ const std::vector<Tensor>& input_tensors = builder.inputs();
+ if (VLOG_IS_ON(1)) {
+ for (const Tensor& input : input_tensors) {
+ VLOG(1) << "Input: " << input.DebugString();
+ }
+ }
+
+ std::vector<std::pair<string, Tensor>> expected_feeds(expected_inputs.size());
+ std::vector<std::pair<string, Tensor>> test_feeds(test_inputs.size());
+ ASSERT_EQ(input_tensors.size(), expected_inputs.size());
+ ASSERT_EQ(input_tensors.size(), test_inputs.size());
+
+ for (int i = 0; i < input_tensors.size(); ++i) {
+ expected_feeds[i] = {expected_inputs[i], input_tensors[i]};
+ test_feeds[i] = {test_inputs[i], input_tensors[i]};
+ }
+
+ std::vector<Tensor> expected_outputs, test_outputs;
+ VLOG(1) << "Running expected graph";
+ Status s =
+ session_->Run(expected_feeds, expected_fetches, {}, &expected_outputs);
+ if (!s.ok()) {
+ VLOG(1) << "Expected graph failed with status: " << s << ". Skipping test";
+ return;
+ }
+
+ VLOG(1) << "Running test graph";
+ TF_ASSERT_OK(session_->Run(test_feeds, test_fetches, {}, &test_outputs));
+
+ ASSERT_EQ(expected_outputs.size(), test_outputs.size());
+ for (int j = 0; s.ok() && j < test_outputs.size(); ++j) {
+ s = TensorsAreClose(expected_outputs[j], test_outputs[j], atol, rtol);
+ }
+ TF_EXPECT_OK(s);
+}
+
+// Helper that converts 'values' to an int32 or int64 Tensor.
+Tensor AsIntTensor(DataType dtype, const std::vector<int64>& values) {
+ switch (dtype) {
+ case DT_INT32: {
+ std::vector<int32> values32(values.begin(), values.end());
+ return test::AsTensor<int32>(values32);
+ }
+ case DT_INT64:
+ return test::AsTensor<int64>(values);
+ default:
+ CHECK(false);
+ }
+}
+
+TEST_F(OpTest, Abs) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Abs").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Add) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, AddN) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ int n = std::uniform_int_distribution<int>(1, 5)(generator());
+
+ auto shape = RandomDims();
+
+ OpTestBuilder builder("AddN");
+ builder.Attr("T", type);
+ builder.Attr("N", n);
+ for (int i = 0; i < n; ++i) {
+ builder.Input(RandomTensor(type, shape));
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, All) {
+ Repeatedly([this]() {
+ Tensor data = RandomTensor(DT_BOOL);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("All").Input(data).Input(indices).Attr("keep_dims",
+ keep_dims));
+ });
+}
+
+TEST_F(OpTest, Any) {
+ Repeatedly([this]() {
+ Tensor data = RandomTensor(DT_BOOL);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Any").Input(data).Input(indices).Attr("keep_dims",
+ keep_dims));
+ });
+}
+
+TEST_F(OpTest, AvgPool) {
+ Repeatedly([this]() {
+ std::uniform_int_distribution<int> random_int(1, 5);
+ int kernel_rows = random_int(generator()),
+ kernel_cols = random_int(generator());
+ int stride_rows = random_int(generator()),
+ stride_cols = random_int(generator());
+ string padding = Choose<string>({"SAME", "VALID"});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("AvgPool")
+ .Input(
+ RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows),
+ RandomDim(kernel_cols), RandomDim(1)}))
+ .Attr("T", DT_FLOAT)
+ .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
+ .Attr("strides", {1, stride_rows, stride_cols, 1})
+ .Attr("padding", padding)
+ .Attr("data_format", "NHWC"));
+ });
+ // TODO(phawkins): the CPU device only implements spatial pooling. Add tests
+ // for batch pooling when supported.
+}
+
+TEST_F(OpTest, AvgPoolGrad) {
+ Repeatedly([this]() {
+ int batch = RandomDim(1), features = RandomDim(1);
+ WindowedDims d = ChooseWindowedDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("AvgPoolGrad")
+ .Input(test::AsTensor<int32>(
+ {batch, d.input_rows, d.input_cols, features}))
+ .Input(RandomTensor(
+ DT_FLOAT, {batch, d.output_rows, d.output_cols, features}))
+ .Attr("T", DT_FLOAT)
+ .Attr("ksize", {1, d.kernel_rows, d.kernel_cols, 1})
+ .Attr("strides", {1, d.stride_rows, d.stride_cols, 1})
+ .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
+ .Attr("data_format", "NHWC"));
+ });
+}
+
+TEST_F(OpTest, BatchMatMul) {
+ Repeatedly([this]() {
+ std::vector<int64> output_dims = RandomDims(2, 5, 0, 7);
+ int64 ndims = output_dims.size();
+ int64 inner_dim = RandomDim();
+ std::vector<int64> x_dims(output_dims), y_dims(output_dims);
+ x_dims[ndims - 1] = inner_dim;
+ y_dims[ndims - 2] = inner_dim;
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
+ .Input(RandomTensor(DT_FLOAT, x_dims))
+ .Input(RandomTensor(DT_FLOAT, y_dims))
+ .Attr("T", DT_FLOAT));
+
+ std::swap(x_dims[ndims - 1], x_dims[ndims - 2]);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
+ .Input(RandomTensor(DT_FLOAT, x_dims))
+ .Input(RandomTensor(DT_FLOAT, y_dims))
+ .Attr("T", DT_FLOAT)
+ .Attr("adj_x", true));
+
+ std::swap(y_dims[ndims - 1], y_dims[ndims - 2]);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
+ .Input(RandomTensor(DT_FLOAT, x_dims))
+ .Input(RandomTensor(DT_FLOAT, y_dims))
+ .Attr("T", DT_FLOAT)
+ .Attr("adj_x", true)
+ .Attr("adj_y", true));
+
+ std::swap(x_dims[ndims - 1], x_dims[ndims - 2]);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul")
+ .Input(RandomTensor(DT_FLOAT, x_dims))
+ .Input(RandomTensor(DT_FLOAT, y_dims))
+ .Attr("T", DT_FLOAT)
+ .Attr("adj_y", true));
+ });
+}
+
+TEST_F(OpTest, BiasAdd) {
+ Repeatedly([this]() {
+ auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
+ auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)});
+ // TODO(phawkins): test both data formats.
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BiasAdd").Input(x).Input(y).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, BiasAddGrad) {
+ Repeatedly([this]() {
+ auto x = RandomTensor(DT_FLOAT);
+ // TODO(phawkins): test both data formats.
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BiasAddGrad").Input(x).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, BiasAddV1) {
+ Repeatedly([this]() {
+ auto x = RandomTensor(DT_FLOAT, RandomDims(2, kDefaultMaxRank));
+ auto y = RandomTensor(DT_FLOAT, {x.dim_size(x.dims() - 1)});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("BiasAddV1").Input(x).Input(y).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, BroadcastGradientArgs) {
+ Repeatedly([this]() {
+ // TODO(phawkins): only int32 seems to be implemented in Tensorflow.
+ // DataType type = Choose<DataType>({DT_INT32, DT_INT64});
+ DataType type = DT_INT32;
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BroadcastGradientArgs")
+ .Input(AsIntTensor(type, dims.first))
+ .Input(AsIntTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Cast) {
+ Repeatedly([this]() {
+ DataType src_type, dst_type;
+ src_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
+ dst_type = Choose<DataType>({DT_INT32, DT_FLOAT, DT_BOOL});
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast")
+ .Input(RandomTensor(src_type))
+ .Attr("SrcT", src_type)
+ .Attr("DstT", dst_type));
+ });
+}
+
+TEST_F(OpTest, Ceil) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Ceil")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Concat) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ int n = std::uniform_int_distribution<int>(2, 5)(generator());
+
+ std::vector<int64> dims = RandomDims(1);
+ int concat_dim =
+ std::uniform_int_distribution<int32>(0, dims.size() - 1)(generator());
+
+ OpTestBuilder builder("Concat");
+ builder.Input(test::AsScalar<int32>(concat_dim));
+ builder.Attr("T", type);
+ builder.Attr("N", n);
+ for (int i = 0; i < n; ++i) {
+ std::vector<int64> shape = dims;
+ shape[concat_dim] = RandomDim();
+ builder.Input(RandomTensor(type, shape));
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, ConcatOffset) {
+ Repeatedly([this]() {
+ int n = std::uniform_int_distribution<int>(2, 5)(generator());
+
+ std::vector<int64> dims = RandomDims(1);
+ int concat_dim =
+ std::uniform_int_distribution<int32>(0, dims.size() - 1)(generator());
+
+ OpTestBuilder builder("ConcatOffset");
+ builder.Input(test::AsScalar<int32>(concat_dim));
+ builder.Attr("N", n);
+ for (int i = 0; i < n; ++i) {
+ std::vector<int32> shape(dims.begin(), dims.end());
+ shape[concat_dim] = RandomDim();
+ builder.Input(test::AsTensor<int32>(shape));
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, Conv2D) {
+ Repeatedly([this]() {
+ WindowedDims d = ChooseWindowedDims();
+ std::uniform_int_distribution<int> random_int(1, 5);
+ int features_in = random_int(generator());
+ int features_out = random_int(generator());
+ Tensor data = RandomTensor(
+ DT_FLOAT, {RandomDim(), d.input_rows, d.input_cols, features_in});
+
+ Tensor kernel = RandomTensor(
+ DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Conv2D")
+ .Input(data)
+ .Input(kernel)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, d.stride_rows, d.stride_cols, 1})
+ .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
+ .Attr("data_format", "NHWC"));
+ });
+}
+
+TEST_F(OpTest, Conv2DBackpropFilter) {
+ Repeatedly([this]() {
+ WindowedDims d = ChooseWindowedDims();
+ std::uniform_int_distribution<int> random_int(1, 5);
+ int features_in = random_int(generator());
+ int features_out = random_int(generator());
+ int32 batch = RandomDim();
+ Tensor activations = RandomTensor(
+ DT_FLOAT, {batch, d.input_rows, d.input_cols, features_in});
+ Tensor backprop = RandomTensor(
+ DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out});
+ Tensor kernel_shape = test::AsTensor<int32>(
+ {d.kernel_rows, d.kernel_cols, features_in, features_out});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Conv2DBackpropFilter")
+ .Input(activations)
+ .Input(kernel_shape)
+ .Input(backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, d.stride_rows, d.stride_cols, 1})
+ .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
+ .Attr("data_format", "NHWC"));
+ });
+}
+
+TEST_F(OpTest, Conv2DBackpropInput) {
+ Repeatedly([this]() {
+ WindowedDims d = ChooseWindowedDims();
+ std::uniform_int_distribution<int> random_int(1, 5);
+ int features_in = random_int(generator());
+ int features_out = random_int(generator());
+ int32 batch = RandomDim();
+ Tensor in_shape =
+ test::AsTensor<int32>({batch, d.input_rows, d.input_cols, features_in});
+ Tensor backprop = RandomTensor(
+ DT_FLOAT, {batch, d.output_rows, d.output_cols, features_out});
+ Tensor kernel = RandomTensor(
+ DT_FLOAT, {d.kernel_rows, d.kernel_cols, features_in, features_out});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Conv2DBackpropInput")
+ .Input(in_shape)
+ .Input(kernel)
+ .Input(backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, d.stride_rows, d.stride_cols, 1})
+ .Attr("padding", d.padding == SAME ? "SAME" : "VALID")
+ .Attr("data_format", "NHWC"));
+ });
+}
+
+TEST_F(OpTest, Diag) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Diag")
+ .Input(RandomTensor(type, RandomDims(1)))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, DiagPart) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = RandomDims(1, 3);
+ // Duplicate the random dims.
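+ // For example, dims {2, 3} become doubled_dims {2, 3, 2, 3}.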
+ std::vector<int64> doubled_dims(dims.size() * 2);
+ std::copy(dims.begin(), dims.end(), doubled_dims.begin());
+ std::copy(dims.begin(), dims.end(), doubled_dims.begin() + dims.size());
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DiagPart")
+ .Input(RandomTensor(type, doubled_dims))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Div) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, DynamicStitch) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ int n = std::uniform_int_distribution<int>(2, 5)(generator());
+ OpTestBuilder builder("DynamicStitch");
+ builder.Attr("T", type);
+ builder.Attr("N", n);
+ std::vector<std::vector<int64>> index_dims;
+ int size = 0;
+ // TODO(phawkins): the XLA implementation of DynamicStitch does not
+ // accept an empty set of indices.
+ do {
+ size = 0;
+ index_dims.clear();
+ for (int i = 0; i < n; ++i) {
+ std::vector<int64> dims = RandomDims(0, 3, 0, 5);
+ size += TensorShape(dims).num_elements();
+ index_dims.push_back(dims);
+ }
+ } while (size == 0);
+
+ // Shuffle a permutation of [0, size) across the index tensors so that
+ // together they cover every position of the output.
+ // TODO(phawkins): The documentation for DynamicStitch doesn't require that
+ // the indices cover all positions of the output, but the XLA implementation
+ // does. However, the native TF implementation leaves undefined values if we
+ // don't cover everything, so we can't really test that case anyway.
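+ // As an illustration of the semantics (values hypothetical): with
+ // indices = {[0, 2], [1]} and data = {[a, c], [b]}, DynamicStitch produces
+ // [a, b, c]; the shuffled permutation below guarantees that kind of full
+ // coverage.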
+ std::vector<int32> indices(size);
+ std::iota(indices.begin(), indices.end(), 0);
+ std::shuffle(indices.begin(), indices.end(), generator());
+
+ int pos = 0;
+ for (int i = 0; i < n; ++i) {
+ TensorShape shape(index_dims[i]);
+ Tensor t = test::AsTensor<int32>(
+ gtl::ArraySlice<int32>(indices, pos, shape.num_elements()), shape);
+ builder.Input(t);
+ pos += t.NumElements();
+ }
+
+ std::vector<int64> constant_dims = RandomDims(0, 3, 0, 5);
+ for (int i = 0; i < n; ++i) {
+ std::vector<int64> dims(index_dims[i].begin(), index_dims[i].end());
+ std::copy(constant_dims.begin(), constant_dims.end(),
+ std::back_inserter(dims));
+ Tensor t = RandomTensor(type, dims);
+ builder.Input(t);
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, Equal) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Exp) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Exp").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, ExpandDims) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor in = RandomTensor(type);
+ Tensor dim(DT_INT32, TensorShape());
+ std::uniform_int_distribution<int32> d(-1 - in.dims(), in.dims());
+ dim.scalar<int32>()() = d(generator());
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("ExpandDims").Input(in).Input(dim).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Fill) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor scalar = RandomTensor(type, {});
+ std::vector<int64> dims = RandomDims();
+ std::vector<int32> shape(dims.begin(), dims.end());
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Fill")
+ .Input(test::AsTensor<int32>(shape))
+ .Input(scalar)
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Floor) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Floor")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, FloorDiv) {
+ Repeatedly([this]() {
+ DataType type = DT_INT32;
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorDiv")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, FloorMod) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Greater) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, GreaterEqual) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Reciprocal) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reciprocal")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, L2Loss) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ // TODO(b/31644876): scalars currently crash.
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("L2Loss")
+ .Input(RandomTensor(type, RandomDims(1)))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Less) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, LessEqual) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, LinSpace) {
+ Repeatedly([this]() {
+ auto ToScalar = [](DataType type, int x) {
+ if (type == DT_INT32) return test::AsScalar<int32>(x);
+ return test::AsScalar<int64>(x);
+ };
+ std::uniform_int_distribution<int> distribution(-50, 50);
+ DataType type = Choose<DataType>({DT_INT32, DT_INT64});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LinSpace")
+ .Input(RandomTensor(DT_FLOAT, {}))
+ .Input(RandomTensor(DT_FLOAT, {}))
+ .Input(ToScalar(type, distribution(generator())))
+ .Attr("T", DT_FLOAT)
+ .Attr("Tidx", type));
+ });
+}
+
+TEST_F(OpTest, Log) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Log").Input(RandomTensor(DT_FLOAT)).Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, LogicalAnd) {
+ Repeatedly([this]() {
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LogicalAnd")
+ .Input(RandomTensor(DT_BOOL, dims.first))
+ .Input(RandomTensor(DT_BOOL, dims.second)));
+ });
+}
+
+TEST_F(OpTest, LogicalNot) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LogicalNot").Input(RandomTensor(DT_BOOL)));
+ });
+}
+
+TEST_F(OpTest, LogicalOr) {
+ Repeatedly([this]() {
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LogicalOr")
+ .Input(RandomTensor(DT_BOOL, dims.first))
+ .Input(RandomTensor(DT_BOOL, dims.second)));
+ });
+}
+
+TEST_F(OpTest, LogSoftmax) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("LogSoftmax")
+ .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2)))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, LRN) {
+ Repeatedly([this]() {
+ Tensor data;
+ // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
+ data = RandomTensor(DT_FLOAT, RandomDims(4, 4, 1, 8));
+ // CuDNN requires depth_radius > 0.
+ std::uniform_int_distribution<int> radius(1, data.dim_size(3));
+ std::uniform_real_distribution<float> coeff(0.01, 2.0);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRN")
+ .Input(data)
+ .Attr("T", DT_FLOAT)
+ .Attr("depth_radius", radius(generator()))
+ .Attr("bias", coeff(generator()))
+ .Attr("alpha", coeff(generator()))
+ .Attr("beta", coeff(generator())));
+ });
+}
+
+TEST_F(OpTest, LRNGrad) {
+ Repeatedly([this]() {
+ // TODO(b/31362467): Crashes with 0 dims on GPU. Re-enable when fixed.
+ std::vector<int64> dims = RandomDims(4, 4, 1, 8);
+ Tensor input_grads = RandomTensor(DT_FLOAT, dims);
+ Tensor input_image = RandomTensor(DT_FLOAT, dims);
+ Tensor output_image = RandomTensor(DT_FLOAT, dims);
+ // CuDNN requires depth_radius > 0.
+ std::uniform_int_distribution<int> radius(1, input_grads.dim_size(3));
+ std::uniform_real_distribution<float> coeff(0.0, 2.0);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LRNGrad")
+ .Input(input_grads)
+ .Input(input_image)
+ .Input(output_image)
+ .Attr("T", DT_FLOAT)
+ .Attr("depth_radius", radius(generator()))
+ .Attr("bias", coeff(generator()))
+ .Attr("alpha", coeff(generator()))
+ .Attr("beta", coeff(generator())));
+ });
+}
+
+TEST_F(OpTest, MatMul) {
+ Repeatedly([this]() {
+ int64 x = RandomDim();
+ int64 y = RandomDim();
+ int64 z = RandomDim();
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
+ .Input(RandomTensor(DT_FLOAT, {x, y}))
+ .Input(RandomTensor(DT_FLOAT, {y, z}))
+ .Attr("T", DT_FLOAT));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
+ .Input(RandomTensor(DT_FLOAT, {y, x}))
+ .Input(RandomTensor(DT_FLOAT, {y, z}))
+ .Attr("T", DT_FLOAT)
+ .Attr("transpose_a", true));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
+ .Input(RandomTensor(DT_FLOAT, {x, y}))
+ .Input(RandomTensor(DT_FLOAT, {z, y}))
+ .Attr("T", DT_FLOAT)
+ .Attr("transpose_b", true));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul")
+ .Input(RandomTensor(DT_FLOAT, {y, x}))
+ .Input(RandomTensor(DT_FLOAT, {z, y}))
+ .Attr("T", DT_FLOAT)
+ .Attr("transpose_a", true)
+ .Attr("transpose_b", true));
+ });
+}
+
+TEST_F(OpTest, MatrixDiag) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_BOOL, DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag")
+ .Input(RandomTensor(type, RandomDims(1)))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, MatrixDiagPart) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_BOOL, DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart")
+ .Input(RandomTensor(type, RandomDims(2)))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Max) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ Tensor data = RandomTensor(type);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Max").Input(data).Input(indices).Attr("T", type).Attr(
+ "keep_dims", keep_dims));
+ });
+}
+
+TEST_F(OpTest, Maximum) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, MaxPool) {
+ Repeatedly([this]() {
+ std::uniform_int_distribution<int> random_int(1, 5);
+ int kernel_rows = random_int(generator()),
+ kernel_cols = random_int(generator());
+ int stride_rows = random_int(generator()),
+ stride_cols = random_int(generator());
+ string padding = Choose<string>({"SAME", "VALID"});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("MaxPool")
+ .Input(
+ RandomTensor(DT_FLOAT, {RandomDim(1), RandomDim(kernel_rows),
+ RandomDim(kernel_cols), RandomDim(1)}))
+ .Attr("T", DT_FLOAT)
+ .Attr("ksize", {1, kernel_rows, kernel_cols, 1})
+ .Attr("strides", {1, stride_rows, stride_cols, 1})
+ .Attr("padding", padding)
+ .Attr("data_format", "NHWC"));
+ });
+ // TODO(phawkins): test NCHW format (not supported by CPU)
+}
+
+TEST_F(OpTest, Mean) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ // TODO(phawkins): CPU and XLA produce different outputs when reducing
+ // across a size-0 dimension (NaN vs. 0). For now, require size >= 1.
+ Tensor data = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 1));
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Mean").Input(data).Input(indices).Attr("T", type).Attr(
+ "keep_dims", keep_dims));
+ });
+}
+
+TEST_F(OpTest, Min) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ Tensor data = RandomTensor(type);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Min").Input(data).Input(indices).Attr("T", type).Attr(
+ "keep_dims", keep_dims));
+ });
+}
+
+TEST_F(OpTest, Minimum) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Mod) {
+ Repeatedly([this]() {
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Mod")
+ .Input(RandomTensor(DT_INT32, dims.first))
+ .Input(RandomTensor(DT_INT32, dims.second))
+ .Attr("T", DT_INT32));
+ });
+}
+
+TEST_F(OpTest, Mul) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Neg) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Neg").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, NotEqual) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Pack) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ int n = std::uniform_int_distribution<int>(1, 5)(generator());
+
+ std::vector<int64> dims = RandomDims();
+ int num_dims = dims.size();
+ int axis = std::uniform_int_distribution<int32>(-num_dims - 1,
+ num_dims)(generator());
+
+ OpTestBuilder builder("Pack");
+ builder.Attr("T", type);
+ builder.Attr("N", n);
+ builder.Attr("axis", axis);
+ for (int i = 0; i < n; ++i) {
+ builder.Input(RandomTensor(type, dims));
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+// TODO(b/31741898): crashes on GPU.
+TEST_F(OpTest, Pad) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor t = RandomTensor(type);
+
+ // TODO(b/31741996): re-enable DT_INT64 when bug is fixed.
+ // DataType tpaddings = Choose<DataType>({DT_INT32, DT_INT64});
+ DataType tpaddings = DT_INT32;
+ std::vector<int64> paddings_vec;
+ std::uniform_int_distribution<int> distribution(0, 7);
+ for (int i = 0; i < t.dims(); ++i) {
+ paddings_vec.push_back(distribution(generator()));
+ paddings_vec.push_back(distribution(generator()));
+ }
+ Tensor paddings;
+ CHECK(paddings.CopyFrom(AsIntTensor(tpaddings, paddings_vec),
+ TensorShape({t.dims(), 2})));
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Pad").Input(t).Input(paddings).Attr("T", type).Attr(
+ "Tpaddings", tpaddings));
+ });
+}
+
+TEST_F(OpTest, Pow) {
+ // TODO(phawkins): Feeding large DT_INT32 values to Pow() leads to
+ // nontermination.
+ Repeatedly([this]() {
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Pow")
+ .Input(RandomTensor(DT_FLOAT, dims.first))
+ .Input(RandomTensor(DT_FLOAT, dims.second))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Prod) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ Tensor data = RandomTensor(type);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Prod").Input(data).Input(indices).Attr("T", type).Attr(
+ "keep_dims", keep_dims));
+ });
+}
+
+TEST_F(OpTest, Range) {
+ Repeatedly([this]() {
+ auto ToScalar = [](DataType type, int x) {
+ if (type == DT_INT32) return test::AsScalar<int32>(x);
+ if (type == DT_INT64) return test::AsScalar<int64>(x);
+ if (type == DT_FLOAT) return test::AsScalar<float>(x);
+ if (type == DT_DOUBLE) return test::AsScalar<double>(x);
+ LOG(FATAL) << "Unknown type " << DataTypeString(type);
+ };
+ std::uniform_int_distribution<int> distribution(-50, 50);
+ DataType tidx = Choose<DataType>({DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Range")
+ .Input(ToScalar(tidx, distribution(generator())))
+ .Input(ToScalar(tidx, distribution(generator())))
+ .Input(ToScalar(tidx, distribution(generator())))
+ .Attr("Tidx", tidx));
+ });
+}
+
+TEST_F(OpTest, Rank) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Rank").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, RealDiv) {
+ Repeatedly([this]() {
+ DataType type = DT_FLOAT;
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Relu) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Relu6) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Relu6Grad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims(1);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Relu6Grad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, ReluGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims(1);
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReluGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Reshape) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ std::vector<int64> dims = RandomDims();
+ std::bernoulli_distribution random_bool;
+ std::vector<int64> dims_before, dims_after;
+ for (std::vector<int64>* out : {&dims_before, &dims_after}) {
+ std::shuffle(dims.begin(), dims.end(), generator());
+ for (int64 dim : dims) {
+ // Either add the dimension as a new dimension or merge it with the
+ // previous dimension.
+ if (out->empty() || random_bool(generator())) {
+ out->push_back(dim);
+ } else {
+ out->back() *= dim;
+ }
+ }
+ }
+ Tensor data = RandomTensor(type, dims_before);
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Reshape")
+ .Input(data)
+ .Input(test::AsTensor<int32>(
+ std::vector<int32>(dims_after.begin(), dims_after.end())))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Rsqrt) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Rsqrt")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, RsqrtGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Shape) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Shape").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, ShapeN) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ int n = std::uniform_int_distribution<int>(1, 5)(generator());
+ OpTestBuilder builder("ShapeN");
+ builder.Attr("T", type);
+ builder.Attr("N", n);
+ for (int i = 0; i < n; ++i) {
+ builder.Input(RandomTensor(type));
+ }
+ ExpectTfAndXlaOutputsAreClose(builder);
+ });
+}
+
+TEST_F(OpTest, Sigmoid) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sigmoid")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, SigmoidGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Sign) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Sign").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Size) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Size").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Slice) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor data = RandomTensor(type);
+
+ std::vector<int32> begin(data.dims()), size(data.dims());
+ for (int i = 0; i < data.dims(); ++i) {
+ begin[i] = std::uniform_int_distribution<int32>(
+ 0, data.dim_size(i))(generator());
+ size[i] = std::uniform_int_distribution<int32>(
+ -1, data.dim_size(i) - begin[i])(generator());
+ }
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Slice")
+ .Input(data)
+ .Input(test::AsTensor<int32>(begin))
+ .Input(test::AsTensor<int32>(size))
+ .Attr("T", type)
+ .Attr("Index", DT_INT32));
+ });
+}
+
+TEST_F(OpTest, Softmax) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Softmax")
+ .Input(RandomTensor(DT_FLOAT, RandomDims(2, 2)))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Split) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ std::vector<int64> dims = RandomDims(1);
+ int32 dim = std::uniform_int_distribution<int32>(
+ 0, static_cast<int32>(dims.size()) - 1)(generator());
+ int n = std::uniform_int_distribution<int>(1, 5)(generator());
+ // Round dims[dim] down to a multiple of 'n' so the split is even.
+ dims[dim] /= n;
+ dims[dim] *= n;
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Split")
+ .Input(test::AsScalar<int32>(dim))
+ .Input(RandomTensor(type, dims))
+ .Attr("T", type)
+ .Attr("num_split", n));
+ });
+}
+
+TEST_F(OpTest, Softplus) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Softplus")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, SoftplusGrad) {
+ Repeatedly([this]() {
+ std::vector<int64> dims = RandomDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SoftplusGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, SparseMatMul) {
+ Repeatedly([this]() {
+ int64 x = RandomDim();
+ int64 y = RandomDim();
+ int64 z = RandomDim();
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
+ .Input(RandomTensor(DT_FLOAT, {x, y}))
+ .Input(RandomTensor(DT_FLOAT, {y, z}))
+ .Attr("Ta", DT_FLOAT)
+ .Attr("Tb", DT_FLOAT));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
+ .Input(RandomTensor(DT_FLOAT, {y, x}))
+ .Input(RandomTensor(DT_FLOAT, {y, z}))
+ .Attr("Ta", DT_FLOAT)
+ .Attr("Tb", DT_FLOAT)
+ .Attr("transpose_a", true));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
+ .Input(RandomTensor(DT_FLOAT, {x, y}))
+ .Input(RandomTensor(DT_FLOAT, {z, y}))
+ .Attr("Ta", DT_FLOAT)
+ .Attr("Tb", DT_FLOAT)
+ .Attr("transpose_b", true));
+
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SparseMatMul")
+ .Input(RandomTensor(DT_FLOAT, {y, x}))
+ .Input(RandomTensor(DT_FLOAT, {z, y}))
+ .Attr("Ta", DT_FLOAT)
+ .Attr("Tb", DT_FLOAT)
+ .Attr("transpose_a", true)
+ .Attr("transpose_b", true));
+ });
+}
+
+TEST_F(OpTest, Sqrt) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sqrt")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, SquaredDifference) {
+ Repeatedly([this]() {
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("SquaredDifference")
+ .Input(RandomTensor(DT_FLOAT, dims.first))
+ .Input(RandomTensor(DT_FLOAT, dims.second))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Square) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Square").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Squeeze) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor t = RandomTensor(type, RandomDims(0, kDefaultMaxRank, 0, 5));
+ std::bernoulli_distribution random_bool;
+ std::vector<int> squeeze_dims;
+ for (int i = 0; i < t.dims(); ++i) {
+ if (t.dim_size(i) == 1 && random_bool(generator())) {
+ squeeze_dims.push_back(i);
+ }
+ }
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Squeeze")
+ .Input(t)
+ .Attr("squeeze_dims", squeeze_dims)
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Sub) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Sum) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ Tensor data = RandomTensor(type);
+ Tensor indices = RandomReductionIndices(data.dims());
+ bool keep_dims = Choose<bool>({false, true});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("Sum").Input(data).Input(indices).Attr("T", type).Attr(
+ "keep_dims", keep_dims));
+ });
+}
+
+TEST_F(OpTest, StridedSlice) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor data = RandomTensor(type);
+
+ std::vector<int32> begin(data.dims()), end(data.dims());
+ std::vector<int32> strides(data.dims());
+ for (int i = 0; i < data.dims(); ++i) {
+ begin[i] = std::uniform_int_distribution<int32>(
+ -2 * data.dim_size(i), 2 * data.dim_size(i))(generator());
+ end[i] = std::uniform_int_distribution<int32>(
+ -2 * data.dim_size(i), 2 * data.dim_size(i))(generator());
+ // TODO(b/31360685): support strides other than 1 or -1
+ strides[i] = std::bernoulli_distribution()(generator()) ? 1 : -1;
+ }
+ int64 max_bitmask = (1LL << data.dims()) - 1;
+ std::uniform_int_distribution<int64> bitmask_distribution(0, max_bitmask);
+ int64 begin_mask = bitmask_distribution(generator());
+ int64 end_mask = bitmask_distribution(generator());
+
+ // Create an ellipsis bitmask with at most one bit set.
+ int64 ellipsis_mask = 0;
+ if (data.dims() > 0 && std::bernoulli_distribution()(generator())) {
+ int ellipsis_pos =
+ std::uniform_int_distribution<int>(0, data.dims() - 1)(generator());
+ ellipsis_mask = 1LL << ellipsis_pos;
+ }
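+ // For reference: bit i of begin_mask (resp. end_mask) tells StridedSlice to
+ // ignore begin[i] (resp. end[i]) and use the largest possible range for
+ // dimension i, e.g. begin_mask = 0b10 on a rank-2 input ignores begin[1].
+ // Randomizing the masks exercises many combinations of explicit and
+ // implicit bounds.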
+
+ int64 new_axis_mask = bitmask_distribution(generator());
+ int64 shrink_axis_mask = bitmask_distribution(generator());
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("StridedSlice")
+ .Input(data)
+ .Input(test::AsTensor<int32>(begin))
+ .Input(test::AsTensor<int32>(end))
+ .Input(test::AsTensor<int32>(strides))
+ .Attr("T", type)
+ .Attr("Index", DT_INT32)
+ .Attr("begin_mask", begin_mask)
+ .Attr("end_mask", end_mask)
+ .Attr("ellipsis_mask", ellipsis_mask)
+ .Attr("new_axis_mask", new_axis_mask)
+ .Attr("shrink_axis_mask", shrink_axis_mask));
+ });
+}
+
+TEST_F(OpTest, StridedSliceGrad) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+
+ // Dimensions of the forward input.
+ std::vector<int64> dims = RandomDims();
+
+ std::vector<int64> begin(dims.size()), end(dims.size());
+ std::vector<int64> strides(dims.size());
+ for (int i = 0; i < dims.size(); ++i) {
+ begin[i] = std::uniform_int_distribution<int64>(-2 * dims[i],
+ 2 * dims[i])(generator());
+ end[i] = std::uniform_int_distribution<int64>(-2 * dims[i],
+ 2 * dims[i])(generator());
+ strides[i] = std::uniform_int_distribution<int64>(
+ -2 * dims[i], 2 * dims[i])(generator());
+ }
+ int64 max_bitmask = (1LL << dims.size()) - 1;
+ std::uniform_int_distribution<int64> bitmask_distribution(0, max_bitmask);
+ int64 begin_mask = bitmask_distribution(generator());
+ int64 end_mask = bitmask_distribution(generator());
+
+ // Create an ellipsis bitmask with at most one bit set.
+ int64 ellipsis_mask = 0;
+ if (!dims.empty() && std::bernoulli_distribution()(generator())) {
+ int ellipsis_pos =
+ std::uniform_int_distribution<int>(0, dims.size() - 1)(generator());
+ ellipsis_mask = 1LL << ellipsis_pos;
+ }
+
+ int64 new_axis_mask = bitmask_distribution(generator());
+ int64 shrink_axis_mask = bitmask_distribution(generator());
+
+ // TODO(phawkins): use shape inference for the forward op to compute the
+ // gradient shape for the backward op. At present, there is a low
+ // probability of the golden op succeeding.
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("StridedSliceGrad")
+ .Input(test::AsTensor<int64>(dims))
+ .Input(test::AsTensor<int64>(begin))
+ .Input(test::AsTensor<int64>(end))
+ .Input(test::AsTensor<int64>(strides))
+ .Input(RandomTensor(type, RandomDims(1)))
+ .Attr("T", type)
+ .Attr("Index", DT_INT64)
+ .Attr("begin_mask", begin_mask)
+ .Attr("end_mask", end_mask)
+ .Attr("ellipsis_mask", ellipsis_mask)
+ .Attr("new_axis_mask", new_axis_mask)
+ .Attr("shrink_axis_mask", shrink_axis_mask));
+ });
+}
+
+TEST_F(OpTest, Tanh) {
+ Repeatedly([this]() {
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tanh")
+ .Input(RandomTensor(DT_FLOAT))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, TanhGrad) {
+ Repeatedly([this]() {
+ auto dims = RandomDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad")
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Input(RandomTensor(DT_FLOAT, dims))
+ .Attr("T", DT_FLOAT));
+ });
+}
+
+TEST_F(OpTest, Tile) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor t = RandomTensor(type, RandomDims(1));
+ std::vector<int32> multiples(t.dims());
+ for (int i = 0; i < t.dims(); ++i) {
+ multiples[i] = std::uniform_int_distribution<int>(1, 3)(generator());
+ }
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Tile")
+ .Input(t)
+ .Input(test::AsTensor<int32>(multiples))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, Transpose) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>(kAllXlaTypes);
+ Tensor data = RandomTensor(type);
+ std::vector<int32> perm(data.dims());
+ std::iota(perm.begin(), perm.end(), 0);
+ std::shuffle(perm.begin(), perm.end(), generator());
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Transpose")
+ .Input(data)
+ .Input(test::AsTensor<int32>(perm))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, TruncateDiv) {
+ Repeatedly([this]() {
+ DataType type = DT_INT32;
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateDiv")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, TruncateMod) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ auto dims = BroadcastableDims();
+ ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod")
+ .Input(RandomTensor(type, dims.first))
+ .Input(RandomTensor(type, dims.second))
+ .Attr("T", type));
+ });
+}
+
+TEST_F(OpTest, ZerosLike) {
+ Repeatedly([this]() {
+ DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});
+ ExpectTfAndXlaOutputsAreClose(
+ OpTestBuilder("ZerosLike").Input(RandomTensor(type)).Attr("T", type));
+ });
+}
+
+} // anonymous namespace
+} // namespace tensorflow
+
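+// Example invocation (the binary name is illustrative; the flags are the ones
+// defined in main() below):
+//   randomized_tests --tf_xla_test_device=CPU --tf_xla_test_repetitions=20 \
+//       --tf_xla_test_use_jit=true --tf_xla_random_seed=42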
+int main(int argc, char** argv) {
+ tensorflow::tf_xla_test_device_ptr = new tensorflow::string("GPU");
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag(
+ "tf_xla_random_seed", &tensorflow::tf_xla_random_seed,
+ "Random seed to use for XLA tests. <= 0 means choose a seed "
+ "nondetermistically."),
+ // TODO(phawkins): it might make more sense to run each test up to a
+ // configurable time bound.
+ tensorflow::Flag("tf_xla_test_repetitions",
+ &tensorflow::tf_xla_test_repetitions,
+ "Number of repetitions for each test."),
+ tensorflow::Flag("tf_xla_test_device", tensorflow::tf_xla_test_device_ptr,
+ "Tensorflow device type to use for test"),
+ tensorflow::Flag("tf_xla_test_use_jit", &tensorflow::tf_xla_test_use_jit,
+ "Use JIT compilation for the operator under test"),
+ };
+ tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ // XLA devices register kernels at construction time; create and destroy all
+ // known devices to make sure the kernels are registered.
+ std::vector<tensorflow::Device*> devices;
+ TF_CHECK_OK(tensorflow::DeviceFactory::AddDevices(
+ tensorflow::SessionOptions(), "", &devices));
+ for (tensorflow::Device* device : devices) {
+ delete device;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py
new file mode 100644
index 0000000000..efda2cc207
--- /dev/null
+++ b/tensorflow/compiler/tests/reduce_ops_test.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for reduction operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class ReduceOpsTest(XLATestCase):
+
+ def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
+ rtol=1e-4, atol=1e-4):
+ """Tests that the output of 'tf_reduce_fn' matches numpy's output."""
+
+ for test_input in test_inputs:
+ with self.test_session() as sess:
+ with self.test_scope():
+ a = array_ops.placeholder(dtype)
+ index = array_ops.placeholder(dtypes.int32)
+ out = tf_reduce_fn(a, index)
+ result = sess.run(out, {a: test_input, index: [0]})
+ self.assertAllClose(result, np_reduce_fn(test_input, axis=0),
+ rtol=rtol, atol=atol)
+
+ result = sess.run(out, {a: test_input, index: [1]})
+ self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
+ rtol=rtol, atol=atol)
+
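+ # Negative axes count from the end; for these rank-2 inputs, [-1] is axis 1.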
+ result = sess.run(out, {a: test_input, index: [-1]})
+ self.assertAllClose(result, np_reduce_fn(test_input, axis=1),
+ rtol=rtol, atol=atol)
+
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
+ sess.run(out, {a: test_input, index: [-33]})
+
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
+ sess.run(out, {a: test_input, index: [2]})
+
+ FLOAT_DATA = [
+ np.zeros(shape=(2, 0)),
+ np.zeros(shape=(0, 30)),
+ np.arange(1, 7).reshape(2, 3),
+ np.arange(-10, -4).reshape(2, 3),
+ np.arange(-4, 2).reshape(2, 3),
+ ]
+ NONEMPTY_FLOAT_DATA = [
+ np.arange(1, 7).reshape(2, 3),
+ np.arange(-10, -4).reshape(2, 3),
+ np.arange(-4, 2).reshape(2, 3),
+ ]
+ BOOL_DATA = [
+ np.array([], dtype=np.bool).reshape(2, 0),
+ np.array([], dtype=np.bool).reshape(0, 3),
+ np.array([[False, True, False], [True, True, False]]),
+ ]
+
+ def testReduceSum(self):
+ self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
+ self.FLOAT_DATA)
+
+ def testReduceProd(self):
+ self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
+ self.FLOAT_DATA)
+
+ def testReduceMin(self):
+
+ def reference_min(inp, axis):
+ """Wrapper around np.amin that returns +infinity for an empty input."""
+ if inp.shape[axis] == 0:
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('inf'))
+ return np.amin(inp, axis)
+
+ self._testReduction(math_ops.reduce_min, reference_min, np.float32,
+ self.FLOAT_DATA)
+
+ def testReduceMax(self):
+
+ def reference_max(inp, axis):
+ """Wrapper around np.amax that returns -infinity for an empty input."""
+ if inp.shape[axis] == 0:
+ return np.full(inp.shape[0:axis] + inp.shape[axis + 1:], float('-inf'))
+ return np.amax(inp, axis)
+
+ self._testReduction(math_ops.reduce_max, reference_max, np.float32,
+ self.FLOAT_DATA)
+
+ def testReduceMean(self):
+ # TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
+ # reducing across zero inputs.
+ self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
+ self.NONEMPTY_FLOAT_DATA)
+
+ def testReduceAll(self):
+ self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)
+
+ def testReduceAny(self):
+ self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA)
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
new file mode 100644
index 0000000000..22024f4511
--- /dev/null
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -0,0 +1,110 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for ternary operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import googletest
+
+
+class TernaryOpsTest(XLATestCase):
+
+ def _testTernary(self, op, a, b, c, expected):
+ with self.test_session() as session:
+ with self.test_scope():
+ pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
+ pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
+ pc = array_ops.placeholder(dtypes.as_dtype(c.dtype), c.shape, name="c")
+ output = op(pa, pb, pc)
+ result = session.run(output, {pa: a, pb: b, pc: c})
+ self.assertAllClose(result, expected, rtol=1e-3)
+
+ def testLinspace(self):
+ self._testTernary(
+ math_ops.linspace,
+ np.float32(1),
+ np.float32(2),
+ np.int32(1),
+ expected=np.array([1], dtype=np.float32))
+ self._testTernary(
+ math_ops.linspace,
+ np.float32(1),
+ np.float32(4),
+ np.int32(3),
+ expected=np.array([1, 2.5, 4], dtype=np.float32))
+
+ def testRange(self):
+ self._testTernary(
+ math_ops.range,
+ np.int32(1),
+ np.int32(2),
+ np.int32(1),
+ expected=np.array([1], dtype=np.int32))
+ self._testTernary(
+ math_ops.range,
+ np.int32(1),
+ np.int32(7),
+ np.int32(2),
+ expected=np.array([1, 3, 5], dtype=np.int32))
+
+ def testSelect(self):
+ self._testTernary(
+ array_ops.where,
+ np.array(0, dtype=np.bool),
+ np.array(2, dtype=np.float32),
+ np.array(7, dtype=np.float32),
+ expected=np.array(7, dtype=np.float32))
+
+ self._testTernary(
+ array_ops.where,
+ np.array([0, 1, 1, 0], dtype=np.bool),
+ np.array([1, 2, 3, 4], dtype=np.float32),
+ np.array([5, 6, 7, 8], dtype=np.float32),
+ expected=np.array([5, 2, 3, 8], dtype=np.float32))
+
+ self._testTernary(
+ array_ops.where,
+ np.array([0, 1, 0], dtype=np.bool),
+ np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
+ np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
+ expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
+
+ def testSlice(self):
+ for dtype in self.numeric_types:
+ self._testTernary(
+ array_ops.slice,
+ np.array([[], [], []], dtype=dtype),
+ np.array([1, 0], dtype=np.int32),
+ np.array([2, 0], dtype=np.int32),
+ expected=np.array([[], []], dtype=dtype))
+
+ self._testTernary(
+ array_ops.slice,
+ np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype),
+ np.array([0, 1], dtype=np.int32),
+ np.array([2, 1], dtype=np.int32),
+ expected=np.array([[2], [5]], dtype=dtype))
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
new file mode 100644
index 0000000000..33e0424e60
--- /dev/null
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -0,0 +1,346 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for XLA JIT compiler."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import googletest
+
+
+class UnaryOpsTest(XLATestCase):
+ """Test cases for unary operators."""
+
+ def _testUnary(self, op, inp, expected, equality_test=None):
+ with self.test_session() as session:
+ with self.test_scope():
+ pinp = array_ops.placeholder(
+ dtypes.as_dtype(inp.dtype), inp.shape, name="a")
+ output = op(pinp)
+ result = session.run(output, {pinp: inp})
+ if equality_test is None:
+ equality_test = self.assertAllClose
+ equality_test(result, expected, rtol=1e-3)
+
+ def ListsAreClose(self, result, expected, rtol):
+ """Tests closeness of two lists of floats."""
+ self.assertEqual(len(result), len(expected))
+ for i in range(len(result)):
+ self.assertAllClose(result[i], expected[i], rtol)
+
+ def testAllTypeOps(self):
+ for dtype in self.numeric_types:
+ self._testUnary(
+ array_ops.diag,
+ np.array([1, 2, 3, 4], dtype=dtype),
+ np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
+ dtype=dtype))
+ self._testUnary(
+ array_ops.diag_part,
+ np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
+ np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.identity,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[-1, 1]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.matrix_diag,
+ np.array([[1, 2], [3, 4]], dtype=dtype),
+ np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
+ self._testUnary(
+ array_ops.matrix_diag_part,
+ np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype),
+ np.array([[0, 5], [8, 13], [16, 21]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.prevent_gradient,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[-1, 1]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.squeeze,
+ np.array([[[[[]]]]], dtype=dtype),
+ expected=np.array([], dtype=dtype))
+ self._testUnary(
+ array_ops.squeeze,
+ np.array([[[1], [2]]], dtype=dtype),
+ expected=np.array([1, 2], dtype=dtype))
+ self._testUnary(
+ array_ops.squeeze,
+ np.array([[[1]], [[2]]], dtype=dtype),
+ expected=np.array([1, 2], dtype=dtype))
+ self._testUnary(
+ array_ops.squeeze,
+ np.array([[[1, 2], [3, 4]]], dtype=dtype),
+ expected=np.array([[1, 2], [3, 4]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.stop_gradient,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[-1, 1]], dtype=dtype))
+
+ def testFloatOps(self):
+ for dtype in self.float_types:
+ self._testUnary(
+ math_ops.ceil,
+ np.array([[-1.7, 1.2]], dtype=dtype),
+ expected=np.array([[-1, 2]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.exp,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[0.36787945, 2.7182817]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.floor,
+ np.array([[-1.7, 1.2]], dtype=dtype),
+ expected=np.array([[-2, 1]], dtype=dtype))
+
+ # Tests for tf.nn ops.
+ self._testUnary(
+ nn_ops.l2_loss, np.array([[[]]], dtype=dtype), expected=dtype(0))
+
+ # TODO(b/31644876): enable this test case when fixed.
+ # self._testUnary(tf.nn.l2_loss, dtype(4), dtype(10))
+
+ self._testUnary(
+ nn_ops.l2_loss, np.array([[-2, 4]], dtype=dtype), expected=dtype(10))
+
+ self._testUnary(
+ math_ops.reciprocal,
+ np.array([[1, 2]], dtype=dtype),
+ expected=np.array([[1, 0.5]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.log,
+ np.array([[1, 2]], dtype=dtype),
+ expected=np.array([[0, 0.69314718]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.rsqrt,
+ np.array([[4, 16]], dtype=dtype),
+ expected=np.array([[0.5, 0.25]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.sigmoid,
+ np.array(
+ [[1, 1, 1, 1],
+ [1, 2, 3, 4]],
+ dtype=dtype),
+ expected=np.array(
+ [[0.7310586, 0.7310586, 0.7310586, 0.7310586],
+ [0.7310586, 0.880797, 0.95257413, 0.98201376]],
+ dtype=dtype))
+
+ self._testUnary(
+ math_ops.sqrt,
+ np.array([[4, 9]], dtype=dtype),
+ expected=np.array([[2, 3]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.tanh,
+ np.array(
+ [[1, 1, 1, 1],
+ [1, 2, 3, 4]],
+ dtype=dtype),
+ expected=np.array(
+ [[0.76159418, 0.76159418, 0.76159418, 0.76159418],
+ [0.76159418, 0.96402758, 0.99505478, 0.99932933]],
+ dtype=dtype))
+
+ self._testUnary(
+ nn_ops.log_softmax,
+ np.array(
+ [[1, 1, 1, 1],
+ [1, 2, 3, 4]],
+ dtype=dtype),
+ expected=np.array(
+ [[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
+ [-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
+ dtype=dtype))
+
+ self._testUnary(
+ nn_ops.relu,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[0, 1]], dtype=dtype))
+
+ self._testUnary(
+ nn_ops.relu6,
+ np.array([[-0.05, 6.05, 5]], dtype=dtype),
+ expected=np.array([[0, 6, 5]], dtype=dtype))
+
+ self._testUnary(
+ nn_ops.softmax,
+ np.array(
+ [[1, 1, 1, 1],
+ [1, 2, 3, 4]],
+ dtype=dtype),
+ expected=np.array(
+ [[0.25, 0.25, 0.25, 0.25],
+ [0.032058604, 0.087144323, 0.23688284, 0.64391428]],
+ dtype=dtype))
+
+ self._testUnary(
+ nn_ops.softplus,
+ np.array([[-2, 0, 8]], dtype=dtype),
+ expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype))
+
+ def testNumericOps(self):
+ for dtype in self.numeric_types:
+ self._testUnary(
+ math_ops.abs,
+ np.array([[2, -1]], dtype=dtype),
+ expected=np.array([[2, 1]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.neg,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([[1, -1]], dtype=dtype))
+
+ self._testUnary(
+ math_ops.square,
+ np.array([[-2, 3]], dtype=dtype),
+ expected=np.array([[4, 9]], dtype=dtype))
+
+ self._testUnary(
+ array_ops.zeros_like,
+ np.array([[4, 3], [2, 1]], dtype=dtype),
+ expected=np.array([[0, 0], [0, 0]], dtype=dtype))
+
+ def testLogicalOps(self):
+ self._testUnary(
+ math_ops.logical_not,
+ np.array([[True, False], [False, True]], dtype=np.bool),
+ expected=np.array([[False, True], [True, False]], dtype=np.bool))
+
+ def testBiasAddGrad(self):
+ self._testUnary(
+ gen_nn_ops.bias_add_grad,
+ np.array([[1., 2.], [3., 4.]], dtype=np.float32),
+ expected=np.array([4., 6.], dtype=np.float32))
+
+ self._testUnary(lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
+ np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
+ dtype=np.float32),
+ expected=np.array([10., 26.], dtype=np.float32))
+
+ def testCast(self):
+ shapes = [[], [4], [2, 3], [2, 0, 4]]
+ types = [dtypes.bool, dtypes.int32, dtypes.float32]
+ for shape in shapes:
+ for src_type in types:
+ for dst_type in types:
+ src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype)
+ src = src.reshape(shape)
+
+ dst = src.astype(dst_type.as_numpy_dtype)
+ self._testUnary(
+ lambda x, dst_type=dst_type: math_ops.cast(x, dst_type),
+ src,
+ expected=dst)
+
+ def testInvertPermutation(self):
+ self._testUnary(
+ array_ops.invert_permutation,
+ np.array([1, 2, 0], np.int32),
+ expected=np.array([2, 0, 1], dtype=np.int32))
+
+ def testRank(self):
+ rank_op = lambda x: array_ops.rank_internal(x, optimize=False)
+ for dtype in self.numeric_types:
+ self._testUnary(rank_op, dtype(7), expected=np.int32(0))
+ self._testUnary(
+ rank_op, np.array(
+ [[], []], dtype=dtype), expected=np.int32(2))
+ self._testUnary(
+ rank_op, np.array(
+ [-1, 1], dtype=dtype), expected=np.int32(1))
+ self._testUnary(
+ rank_op, np.array(
+ [[-1, 1]], dtype=dtype), expected=np.int32(2))
+ self._testUnary(
+ rank_op,
+ np.array([[-1], [1], [4]], dtype=dtype),
+ expected=np.int32(2))
+
+ def testShape(self):
+ shape_op = lambda x: array_ops.shape_internal(x, optimize=False)
+ for dtype in self.numeric_types:
+ self._testUnary(shape_op, dtype(7), expected=np.array([], dtype=np.int32))
+ self._testUnary(
+ shape_op,
+ np.array([[], []], dtype=dtype),
+ expected=np.array([2, 0], dtype=np.int32))
+ self._testUnary(
+ shape_op,
+ np.array([-1, 1], dtype=dtype),
+ expected=np.array([2], dtype=np.int32))
+ self._testUnary(
+ shape_op,
+ np.array([[-1, 1]], dtype=dtype),
+ expected=np.array([1, 2], dtype=np.int32))
+ self._testUnary(
+ shape_op,
+ np.array([[-1], [1], [4]], dtype=dtype),
+ expected=np.array([3, 1], dtype=np.int32))
+
+ def testSize(self):
+ size_op = lambda x: array_ops.size_internal(x, optimize=False)
+ for dtype in self.numeric_types:
+ self._testUnary(size_op, dtype(7), expected=np.int32(1))
+ self._testUnary(
+ size_op, np.array([[], []], dtype=dtype), expected=np.int32(0))
+ self._testUnary(
+ size_op, np.array([-1, 1], dtype=dtype), expected=np.int32(2))
+ self._testUnary(
+ size_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
+ self._testUnary(
+ size_op,
+ np.array([[-1], [1], [4]], dtype=dtype),
+ expected=np.int32(3))
+
+ def testUnpack(self):
+ self._testUnary(
+ array_ops.unpack,
+ np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
+ expected=[
+ np.array([1., 2.], dtype=np.float32),
+ np.array([3., 4.], dtype=np.float32),
+ np.array([5., 6.], dtype=np.float32),
+ ],
+ equality_test=self.ListsAreClose)
+
+ self._testUnary(lambda x: array_ops.unstack(x, axis=1),
+ np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
+ expected=[
+ np.array([1., 3., 5.], dtype=np.float32),
+ np.array([2., 4., 6.], dtype=np.float32),
+ ],
+ equality_test=self.ListsAreClose)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
new file mode 100644
index 0000000000..1388a892ba
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -0,0 +1,81 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for XLA devices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class XlaDeviceTest(test.TestCase):
+
+ def testCopies(self):
+ """Tests that copies between GPU and XLA devices work."""
+ if not test.is_gpu_available():
+ return
+
+ with session_lib.Session() as sess:
+ x = array_ops.placeholder(dtypes.float32, [2])
+ with ops.device("GPU"):
+ y = x * 2
+ with ops.device("device:XLA_CPU:0"):
+ z = y * y
+ with ops.device("GPU"):
+ w = y + z
+ result = sess.run(w, {x: [1.5, 0.5]})
+ self.assertAllClose(result, [12., 2.], rtol=1e-3)
+
+ def testLoops(self):
+ """Tests that loops work on XLA devices."""
+
+ with session_lib.Session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ with ops.device("device:XLA_CPU:0"):
+ c = lambda i, _: math_ops.less(i, 5)
+ b = lambda i, x: (i + 1, x * 2.0 + 1.0)
+ _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))
+
+ result = session.run(y, {x: np.float32(2)})
+ self.assertAllClose(result, np.float32(95), rtol=1e-3)
+
+ def testCond(self):
+ """Tests that tf.cond works on XLA devices."""
+
+ with session_lib.Session() as session:
+ x = array_ops.placeholder(dtypes.float32)
+ y = array_ops.placeholder(dtypes.float32)
+ c = array_ops.placeholder(dtypes.bool)
+ with ops.device("device:XLA_CPU:0"):
+ z = x + 1.0
+ w = control_flow_ops.cond(c, lambda: z, lambda: y)
+ t = math_ops.add(z, w)
+
+ result = session.run(t, {x: np.float32(2), y: np.float32(4), c: True})
+ self.assertAllClose(result, np.float32(6), rtol=1e-3)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
new file mode 100644
index 0000000000..b72e7c9713
--- /dev/null
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Definition of XLA test case."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+import re
+
+from tensorflow.contrib.compiler import jit
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging as logging
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('test_device', None,
+ 'Tensorflow device on which to place operators under test')
+flags.DEFINE_string('types', None, 'Types to test. Comma-separated list.')
+flags.DEFINE_string('disabled_manifest', None,
+ 'Path to a file with a list of tests that should not run.')
+
+
+class XLATestCase(test.TestCase):
+ """XLA test cases are parameterized test cases."""
+
+ def __init__(self, method_name='runTest'):
+ super(XLATestCase, self).__init__(method_name)
+ self.device = FLAGS.test_device
+ self.has_custom_call = (self.device == 'XLA_CPU')
+ self.all_tf_types = [
+ dtypes.DType(types_pb2.DataType.Value(name))
+ for name in FLAGS.types.split(',')
+ ]
+ self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types]
+ self.int_types = [
+ dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_integer
+ ]
+ self.float_types = [
+ dtype.as_numpy_dtype for dtype in self.all_tf_types if dtype.is_floating
+ ]
+ self.numeric_types = self.int_types + self.float_types
+
+ # Parse the manifest file, if any, into a regex identifying tests to
+ # disable. Each non-comment line in the manifest is joined into a single
+ # regex alternative matched against '<TestClass>.<testMethod>'; '#' starts
+ # a comment.
+ self.disabled_regex = None
+ if FLAGS.disabled_manifest is not None:
+ comments_re = re.compile('#.*$')
+ manifest_file = open(FLAGS.disabled_manifest, 'r')
+ lines = manifest_file.read().splitlines()
+ lines = [comments_re.sub('', l).strip() for l in lines]
+ self.disabled_regex = re.compile('|'.join(lines))
+ manifest_file.close()
+
+ def setUp(self):
+ name = '{}.{}'.format(type(self).__name__, self._testMethodName)
+ if self.disabled_regex is not None and self.disabled_regex.match(name):
+ logging.info('Disabled test case: %s', name)
+ self.skipTest('{} is disabled by manifest.'.format(name))
+ return
+ logging.info('Start test case: %s', name)
+
+ def tearDown(self):
+ logging.info('End test case: %s', self._testMethodName)
+
+ @contextlib.contextmanager
+ def test_session(self):
+ """Custom implementation of test_session() for XLA tests.
+
+ We override the standard TensorFlow test_session() since it is too
+ specific to CPU and GPU tests. In particular, we want to disable soft
+ placement and explicitly assign ops to devices under test.
+
+ Yields:
+ A session to use when running a test case.
+ """
+ graph = ops.Graph()
+ with session.Session(graph=graph) as sess, graph.as_default():
+ yield sess
+
+ @contextlib.contextmanager
+ def test_scope(self):
+ """Test scope that runs tests on a TensorFlow/XLA device.
+
+ Places the operators under test on the device named by the --test_device
+ flag.
+
+ Yields:
+ A scope to apply to the operators under test.
+ """
+ with ops.device('device:{}:0'.format(self.device)):
+ yield
+
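+# A typical usage sketch for XLATestCase subclasses (the test class, op and
+# values below are illustrative only and not part of this module):
+#
+#   class ExampleOpsTest(XLATestCase):
+#
+#     def testAbs(self):
+#       with self.test_session() as sess:
+#         with self.test_scope():
+#           x = array_ops.placeholder(dtypes.float32, [3])
+#           y = math_ops.abs(x)
+#         result = sess.run(y, {x: [-1., 0., 2.]})
+#       self.assertAllClose(result, [1., 0., 2.])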
+
+def Benchmark(tf_bench, builder_fn, use_xla_jit, device):
+ """Build a graph and run benchmarks against it, with or without XLA.
+
+ Args:
+ tf_bench: An instance of tf.test.Benchmark, used to run the benchmark.
+ builder_fn: A function that builds a graph when invoked, and returns
+ (name, fetches), where name is the name of the test, and fetches
+ is a list of tensors to fetch as output.
+ use_xla_jit: If true, compile with the XLA JIT; otherwise use regular TF.
+ device: The TensorFlow device to run on, e.g. "cpu", "gpu".
+ """
+
+ with ops.Graph().as_default():
+ name = None
+ targets = []
+ with ops.device(device):
+ fetches = []
+ jit_scope = jit.experimental_jit_scope
+ with jit_scope(compile_ops=use_xla_jit):
+ name, fetches = builder_fn()
+
+ # We only want to benchmark the operations themselves, and not the data
+ # transfer of the result(s). Non-compiled identity ops ensure XLA
+ # doesn't know we're dropping the results, otherwise it might compile
+ # away the entire computation.
+ for fetch in fetches:
+ targets.append(array_ops.identity(fetch).op)
+
+ config = config_pb2.ConfigProto(allow_soft_placement=True)
+ with session.Session(config=config) as sess:
+ sess.run(variables.global_variables_initializer())
+ xla = 'xla_' if use_xla_jit else ''
+ tf_bench.run_op_benchmark(
+ sess, targets, name='%s_%s%s' % (name, xla, device))
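+
+# A minimal usage sketch for Benchmark(); the builder and benchmark class
+# below are illustrative only and not part of this module:
+#
+#   def _matmul_builder():
+#     x = variables.Variable(array_ops.ones([1024, 1024]))
+#     return 'matmul_1024', [math_ops.matmul(x, x)]
+#
+#   class ExampleBenchmark(test.Benchmark):
+#
+#     def benchmarkMatMul(self):
+#       Benchmark(self, _matmul_builder, use_xla_jit=True, device='cpu')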
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
new file mode 100644
index 0000000000..3de9958cd6
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -0,0 +1,193 @@
+licenses(["notice"]) # Apache 2.0
+
+package_group(
+ name = "internal",
+ packages = [
+ "//tensorflow/compiler/aot/...",
+ "//tensorflow/compiler/jit/...",
+ "//tensorflow/compiler/tests/...",
+ "//tensorflow/compiler/tf2xla/...",
+ ],
+)
+
+package_group(
+ name = "friends",
+ includes = [":internal"],
+ packages = ["//tensorflow/..."],
+)
+
+package(
+ default_visibility = [":internal"],
+)
+
+cc_library(
+ name = "xla_compiler",
+ srcs = [
+ "op_registrations.cc",
+ "xla_compilation_device.cc",
+ "xla_compiler.cc",
+ "xla_context.cc",
+ "xla_helpers.cc",
+ "xla_op_kernel.cc",
+ ],
+ hdrs = [
+ "xla_compilation_device.h",
+ "xla_compiler.h",
+ "xla_context.h",
+ "xla_helpers.h",
+ "xla_op_kernel.h",
+ ],
+ deps = [
+ ":common",
+ ":dump_graph",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core/kernels:cwise_op",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "common",
+ srcs = [
+ "literal_util.cc",
+ "shape_util.cc",
+ "str_util.cc",
+ "type_util.cc",
+ ],
+ hdrs = [
+ "literal_util.h",
+ "shape_util.h",
+ "str_util.h",
+ "type_util.h",
+ ],
+ visibility = [":friends"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# Internal targets below this point.
+
+cc_test(
+ name = "str_util_test",
+ srcs = [
+ "str_util_test.cc",
+ ],
+ deps = [
+ ":common",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "literal_util_test",
+ srcs = [
+ "literal_util_test.cc",
+ ],
+ deps = [
+ ":common",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "const_analysis",
+ srcs = ["const_analysis.cc"],
+ hdrs = ["const_analysis.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_test(
+ name = "const_analysis_test",
+ size = "small",
+ srcs = ["const_analysis_test.cc"],
+ deps = [
+ ":const_analysis",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "xla_local_runtime_context",
+ hdrs = ["xla_local_runtime_context.h"],
+ visibility = ["//visibility:public"],
+ deps = ["//tensorflow/core:framework_lite"],
+)
+
+cc_library(
+ name = "dump_graph",
+ srcs = [
+ "dump_graph.cc",
+ "dump_graph_flags.cc",
+ "dump_graph_flags.h",
+ ],
+ hdrs = [
+ "dump_graph.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla/legacy_flags:parse_flags_from_env",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
new file mode 100644
index 0000000000..e072ef7be7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+// Backwards dataflow analysis that finds arguments to a graph that must be
+// compile-time constants.
+Status BackwardsConstAnalysis(const Graph& g,
+ std::vector<bool>* compile_time_const_args) {
+ // TODO(phawkins): annotate these on the kernel registrations, rather than
+ // using a hard-coded list.
+ // (operator, argument) pairs that must be compile-time constants.
+ const std::unordered_multimap<string, string> compile_time_const_inputs = {
+ {"All", "reduction_indices"},
+ {"Any", "reduction_indices"},
+ {"ArgMax", "dimension"},
+ {"AvgPoolGrad", "orig_input_shape"},
+ {"BroadcastGradientArgs", "s0"},
+ {"BroadcastGradientArgs", "s1"},
+ {"Concat", "concat_dim"},
+ {"ConcatV2", "axis"},
+ {"ConcatOffset", "concat_dim"},
+ {"ConcatOffset", "shape"},
+ {"Conv2DBackpropFilter", "filter_sizes"},
+ {"Conv2DBackpropInput", "input_sizes"},
+ {"DynamicStitch", "indices"},
+ {"ExpandDims", "dim"},
+ {"Fill", "dims"},
+ {"InvertPermutation", "x"},
+ {"LinSpace", "start"},
+ {"LinSpace", "stop"},
+ {"LinSpace", "num"},
+ {"Max", "reduction_indices"},
+ {"Mean", "reduction_indices"},
+ {"Min", "reduction_indices"},
+ {"Pad", "paddings"},
+ {"Prod", "reduction_indices"},
+ {"RandomStandardNormal", "shape"},
+ {"RandomUniform", "shape"},
+ {"RandomUniformInt", "shape"},
+ {"Range", "start"},
+ {"Range", "limit"},
+ {"Range", "delta"},
+ {"Reshape", "shape"},
+ {"Slice", "begin"},
+ {"Slice", "size"},
+ {"Split", "split_dim"},
+ {"SplitV", "split_dim"},
+ {"SplitV", "size_splits"},
+ {"StridedSlice", "begin"},
+ {"StridedSlice", "end"},
+ {"StridedSlice", "strides"},
+ {"StridedSliceGrad", "shape"},
+ {"StridedSliceGrad", "begin"},
+ {"StridedSliceGrad", "end"},
+ {"StridedSliceGrad", "strides"},
+ {"Sum", "reduction_indices"},
+ {"Tile", "multiples"},
+ {"Transpose", "perm"}};
+
+ // Operators that don't look at the data of their inputs, just the shapes.
+ const std::unordered_set<string> metadata_ops = {
+ "Rank", "Shape", "ShapeN", "Size",
+ };
+
+ Status status;
+ std::unordered_set<Node*> must_be_const;
+ auto visit = [&status, &metadata_ops, &compile_time_const_inputs,
+ &must_be_const, compile_time_const_args](Node* node) {
+ if (!status.ok()) return;
+
+ // If this is a metadata-only op, don't propagate the const requirement.
+ if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+
+ // If this node must be const, and it isn't a metadata op, then all of its
+ // parents must be const.
+ if (must_be_const.find(node) != must_be_const.end()) {
+ if (node->type_string() == "_Arg") {
+ int index;
+ status = GetNodeAttr(node->def(), "index", &index);
+ if (!status.ok()) return;
+ compile_time_const_args->at(index) = true;
+ return;
+ }
+ for (Node* pred : node->in_nodes()) {
+ must_be_const.insert(pred);
+ }
+ return;
+ }
+
+ // Mark any compile-time constant operator arguments as const.
+ auto range = compile_time_const_inputs.equal_range(node->type_string());
+ if (range.first == range.second) return;
+
+ NameRangeMap input_name_ranges;
+ status = NameRangesForNode(node->def(), node->op_def(), &input_name_ranges,
+ nullptr);
+ if (!status.ok()) return;
+
+ for (auto it = range.first; it != range.second; ++it) {
+ auto name_range = input_name_ranges.find(it->second);
+ if (name_range == input_name_ranges.end()) continue;
+
+ for (Edge const* edge : node->in_edges()) {
+ if (edge->dst_input() >= name_range->second.first &&
+ edge->dst_input() < name_range->second.second) {
+ must_be_const.insert(edge->src());
+ }
+ }
+ }
+ };
+
+ // Post-order traversal visits nodes in reverse topological order for an
+ // acyclic graph.
+ DFS(g, {}, visit);
+ return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/const_analysis.h b/tensorflow/compiler/tf2xla/const_analysis.h
new file mode 100644
index 0000000000..634b97d7e3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/const_analysis.h
@@ -0,0 +1,33 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
+
+#include <vector>
+
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Backwards dataflow analysis that finds arguments (_Arg nodes) to a graph that
+// must be compile-time constants.
+Status BackwardsConstAnalysis(const Graph& graph,
+ std::vector<bool>* compile_time_const_args);
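+//
+// A minimal usage sketch (error handling elided; num_args stands for the
+// number of _Arg nodes in the graph):
+//
+//   std::vector<bool> const_args(num_args, false);
+//   TF_RETURN_IF_ERROR(BackwardsConstAnalysis(graph, &const_args));
+//   // const_args[i] is now true iff _Arg i must be a compile-time constant.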
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_CONST_ANALYSIS_H_
diff --git a/tensorflow/compiler/tf2xla/const_analysis_test.cc b/tensorflow/compiler/tf2xla/const_analysis_test.cc
new file mode 100644
index 0000000000..9d125f8d49
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/const_analysis_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests for the backward const analysis.
+
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(ConstAnalysisTest, Basics) {
+ Scope root = Scope::NewRootScope();
+
+ auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
+ auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
+ auto arg3 = ops::_Arg(root.WithOpName("Arg3"), DT_INT32, 3);
+ auto a = ops::Shape(root, arg0);
+ auto b = ops::Add(root, a, arg1);
+ auto c = ops::Reshape(root, arg2, b);
+ auto d = ops::Mul(root, c, ops::Sum(root, arg3, arg3));
+
+ Graph graph(OpRegistry::Global());
+ TF_ASSERT_OK(root.ToGraph(&graph));
+
+ std::vector<bool> const_args(4, false);
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+
+ // Arg 0 doesn't need to be constant since the graph only uses its shape.
+ // Arg 1 must be constant because it flows to the shape argument of a Reshape.
+ // Arg 2 is used only as the value input to a Reshape and need not be const.
+ // Arg 3 is used as the reduction-indices argument to Sum and must be const.
+ EXPECT_EQ(const_args, std::vector<bool>({false, true, false, true}));
+}
+
+// Regression test for a case where the backward const analysis did
+// not visit nodes in topological order.
+TEST(ConstAnalysisTest, TopologicalOrder) {
+ for (bool order : {false, true}) {
+ Scope root = Scope::NewRootScope();
+
+ auto arg0 = ops::_Arg(root.WithOpName("Arg0"), DT_INT32, 0);
+ auto arg1 = ops::_Arg(root.WithOpName("Arg1"), DT_INT32, 1);
+ auto arg2 = ops::_Arg(root.WithOpName("Arg2"), DT_INT32, 2);
+ auto a = ops::Reshape(root, arg0, arg1);
+ auto b = ops::Reshape(root, arg2, a);
+ if (order) {
+ // Consider both orders for arguments to the Add so we aren't sensitive
+ // to the DFS traversal order.
+ std::swap(a, b);
+ }
+ auto c = ops::Add(root, a, b);
+
+ Graph graph(OpRegistry::Global());
+ TF_ASSERT_OK(root.ToGraph(&graph));
+
+ std::vector<bool> const_args(3, false);
+ TF_ASSERT_OK(BackwardsConstAnalysis(graph, &const_args));
+
+ EXPECT_EQ(const_args, std::vector<bool>({true, true, false}));
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
new file mode 100644
index 0000000000..5aa6f806ac
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for
+// debugging.
+
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+
+#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace dump_graph {
+
+namespace {
+
+struct NameCounts {
+ mutex counts_mutex;
+ std::unordered_map<string, int> counts;
+};
+
+string MakeUniquePath(const string& name) {
+ static NameCounts& instance = *new NameCounts;
+ int count;
+ {
+ mutex_lock lock(instance.counts_mutex);
+ count = instance.counts[name]++;
+ }
+
+ legacy_flags::DumpGraphFlags* flags = legacy_flags::GetDumpGraphFlags();
+ string path = strings::StrCat(flags->tf_dump_graph_prefix, "/", name);
+ if (count > 0) {
+ strings::StrAppend(&path, "_", count);
+ }
+ strings::StrAppend(&path, ".pbtxt");
+ return path;
+}
+
+} // anonymous namespace
+
+string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) {
+ string path = MakeUniquePath(name);
+ TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def));
+ return path;
+}
+
+string DumpGraphToFile(const string& name, Graph const& graph,
+ const FunctionLibraryDefinition* flib_def) {
+ GraphDef graph_def;
+ graph.ToGraphDef(&graph_def);
+ if (flib_def) {
+ *graph_def.mutable_library() = flib_def->ToProto();
+ }
+ return DumpGraphDefToFile(name, graph_def);
+}
+
+string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) {
+ string path = MakeUniquePath(name);
+ TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef));
+ return path;
+}
+
+} // namespace dump_graph
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/dump_graph.h b/tensorflow/compiler/tf2xla/dump_graph.h
new file mode 100644
index 0000000000..bbf01eb90d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for dumping Graphs, GraphDefs, and FunctionDefs to files for
+// debugging.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_
+#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace dump_graph {
+
+// Dumps 'graph_def' to a file, as a GraphDef text proto. Returns the file name
+// chosen.
+//
+// Automatically picks a file name. Prefixes 'name' with the value of the
+// --tf_dump_graph_prefix flag and suffixes it with ".pbtxt" to form a name.
+// If a graph has already been dumped by this process with the same name,
+// suffixes with "_n.pbtxt", where 'n' is a sequence number.
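+//
+// For example (with the default --tf_dump_graph_prefix of /tmp and an
+// arbitrary illustrative name), DumpGraphDefToFile("my_graph", graph_def)
+// writes my_graph.pbtxt under that prefix and returns the chosen path; a
+// second dump with the same name is written with a "_1" suffix.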
+string DumpGraphDefToFile(const string& name, GraphDef const& graph_def);
+
+// Similar to DumpGraphDefToFile, but builds the GraphDef to dump from a 'graph'
+// and an optional function library 'flib_def'. Returns the file name chosen.
+string DumpGraphToFile(const string& name, Graph const& graph,
+ const FunctionLibraryDefinition* flib_def = nullptr);
+
+// Similar to DumpGraphDefToFile, but dumps a function as a FunctionDef text
+// proto. Returns the file name chosen.
+string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef);
+
+} // namespace dump_graph
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_H_
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.cc b/tensorflow/compiler/tf2xla/dump_graph_flags.cc
new file mode 100644
index 0000000000..a6c908ba01
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for the XLA bridge's dump_graph module.
+
+#include <mutex>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/dump_graph_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static DumpGraphFlags* flags;
+static std::vector<Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new DumpGraphFlags;
+ flags->tf_dump_graph_prefix = "/tmp/";
+ flag_list = new std::vector<Flag>({
+ Flag("tf_dump_graph_prefix", &flags->tf_dump_graph_prefix,
+ "Path prefix to which graphs dumped during debugging should be "
+ "written."),
+ });
+ xla::legacy_flags::ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with the XLA bridge's
+// dump_graph module.
+void AppendDumpGraphFlags(std::vector<Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the DumpGraphFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+DumpGraphFlags* GetDumpGraphFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/dump_graph_flags.h b/tensorflow/compiler/tf2xla/dump_graph_flags.h
new file mode 100644
index 0000000000..80a3307d92
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/dump_graph_flags.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
+
+// Legacy flags for the XLA bridge's dump_graph module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with the XLA bridge's
+// dump_graph module.
+void AppendDumpGraphFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with the XLA bridge's
+// dump_graph module.
+typedef struct {
+ string tf_dump_graph_prefix; // Path prefix to which graphs dumped during
+ // debugging should be written.
+} DumpGraphFlags;
+
+// Return a pointer to the DumpGraphFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+DumpGraphFlags* GetDumpGraphFlags();
+
+} // namespace legacy_flags
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_DUMP_GRAPH_FLAGS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
new file mode 100644
index 0000000000..d913f898e9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -0,0 +1,177 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [
+ "//tensorflow/compiler/tf2xla:internal",
+ ],
+ features = ["no_layering_check"],
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
+
+tf_kernel_library(
+ name = "xla_ops",
+ srcs = [
+ "aggregate_ops.cc",
+ "batch_matmul_op.cc",
+ "bcast_ops.cc",
+ "bias_ops.cc",
+ "binary_ops.cc",
+ "cast_op.cc",
+ "concat_op.cc",
+ "conv_ops.cc",
+ "cwise_ops.cc",
+ "declaration_op.cc",
+ "depthwise_conv_ops.cc",
+ "diag_op.cc",
+ "dynamic_stitch_op.cc",
+ "fill_op.cc",
+ "function_ops.cc",
+ "identity_op.cc",
+ "l2loss_op.cc",
+ "lrn_ops.cc",
+ "matmul_op.cc",
+ "no_op.cc",
+ "pack_op.cc",
+ "pad_op.cc",
+ "pooling_ops.cc",
+ "random_ops.cc",
+ "reduction_ops.cc",
+ "reduction_ops_common.cc",
+ "relu_op.cc",
+ "reshape_op.cc",
+ "retval_op.cc",
+ "select_op.cc",
+ "sequence_ops.cc",
+ "shape_op.cc",
+ "slice_op.cc",
+ "softmax_op.cc",
+ "split_op.cc",
+ "strided_slice_op.cc",
+ "tile_ops.cc",
+ "transpose_op.cc",
+ "unary_ops.cc",
+ "unpack_op.cc",
+ ],
+ hdrs = [
+ "cwise_ops.h",
+ "reduction_ops.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core/kernels:concat_lib",
+ "//tensorflow/core/kernels:conv_2d",
+ "//tensorflow/core/kernels:conv_ops",
+ "//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:depthwise_conv_op",
+ "//tensorflow/core/kernels:matmul_op",
+ "//tensorflow/core/kernels:no_op",
+ "//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:pooling_ops",
+ "//tensorflow/core/kernels:sendrecv_ops",
+ "//tensorflow/core/kernels:transpose_op",
+ ],
+)
+
+# Kernels that only work on CPU, because they use XLA custom calls.
+# Only link this when using the CPU backend for XLA.
+#
+# TODO(cwhipkey): move into xla_ops when ops can be registered for
+# CPU compilation only (b/31363654).
+tf_kernel_library(
+ name = "xla_cpu_only_ops",
+ srcs = [
+ "gather_op.cc",
+ "index_ops.cc",
+ ],
+ deps = [
+ ":gather_op_kernel_float_int32",
+ ":gather_op_kernel_float_int64",
+ ":index_ops_kernel_argmax_float_1d",
+ ":index_ops_kernel_argmax_float_2d",
+ "//tensorflow/compiler/tf2xla:common",
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+tf_kernel_library(
+ name = "gather_op_kernel_float_int32",
+ srcs = ["gather_op_kernel_float_int32.cc"],
+ # Makes the custom-call function visible to LLVM during JIT.
+ linkopts = export_dynamic_linkopts,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core/kernels:gather_functor",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "gather_op_kernel_float_int64",
+ srcs = ["gather_op_kernel_float_int64.cc"],
+ # Makes the custom-call function visible to LLVM during JIT.
+ linkopts = export_dynamic_linkopts,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/tf2xla:xla_local_runtime_context",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core/kernels:gather_functor",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "index_ops_kernel_argmax_float_1d",
+ srcs = ["index_ops_kernel_argmax_float_1d.cc"],
+ # Makes the custom-call function visible to LLVM during JIT.
+ linkopts = export_dynamic_linkopts,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "index_ops_kernel_argmax_float_2d",
+ srcs = ["index_ops_kernel_argmax_float_2d.cc"],
+ # Makes the custom-call function visible to LLVM during JIT.
+ linkopts = export_dynamic_linkopts,
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
new file mode 100644
index 0000000000..8f284c3017
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/aggregate_ops.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class AddNOp : public XlaOpKernel {
+ public:
+ explicit AddNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ if (!ctx->ValidateInputsAreSameShape(this)) return;
+
+ OP_REQUIRES(ctx, ctx->num_inputs() >= 1,
+ errors::InvalidArgument("AddN requires at least one argument"));
+
+ xla::ComputationDataHandle sum = ctx->Input(0);
+ for (int i = 1; i < ctx->num_inputs(); ++i) {
+ sum = ctx->builder()->Add(sum, ctx->Input(i));
+ }
+
+ ctx->SetOutput(0, sum);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(AddNOp);
+};
+
+REGISTER_XLA_OP("AddN", AddNOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
new file mode 100644
index 0000000000..637360d149
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific BatchMatMul Op.
+// The current implementation simply unrolls the computation along the batch
+// dimension.
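+// For example, multiplying x of shape [2, 3, 4] by y of shape [2, 4, 5] (with
+// adj_x = adj_y = false) is unrolled into two 3x4 * 4x5 matrix products whose
+// results are concatenated into a [2, 3, 5] output.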
+// TODO(dominikg,phawkins): Use a real batched matmul instead of unrolling.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class BatchMatMulOp : public XlaOpKernel {
+ public:
+ explicit BatchMatMulOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_x", &adj_x_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("adj_y", &adj_y_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape x_shape = ctx->InputShape(0);
+ const TensorShape y_shape = ctx->InputShape(1);
+
+ // Check that both tensors have the same number of dimensions. There must be
+ // at least two (the batch dimensions can be empty).
+ OP_REQUIRES(ctx, x_shape.dims() == y_shape.dims(),
+ errors::InvalidArgument("In[0] and In[1] have different ndims: ",
+ x_shape.DebugString(), " vs. ",
+ y_shape.DebugString()));
+ const int ndims = x_shape.dims();
+ OP_REQUIRES(
+ ctx, ndims >= 2,
+ errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ", ndims));
+
+ // The batch dimensions must be equal and the matrix dimensions must be
+ // valid.
+ std::vector<int64> dimensions;
+ int batch_count = 1;
+ for (int i = 0; i < ndims - 2; ++i) {
+ OP_REQUIRES(
+ ctx, x_shape.dim_size(i) == y_shape.dim_size(i),
+ errors::InvalidArgument("In[0].dim(", i, ") and In[1].dim(", i,
+ ") must be the same: ", x_shape.DebugString(),
+ " vs ", y_shape.DebugString()));
+ dimensions.push_back(x_shape.dim_size(i));
+ batch_count *= x_shape.dim_size(i);
+ }
+
+ int x_inner_dim = adj_x_ ? (ndims - 2) : (ndims - 1);
+ int y_inner_dim = adj_y_ ? (ndims - 1) : (ndims - 2);
+ OP_REQUIRES(
+ ctx, x_shape.dim_size(x_inner_dim) == y_shape.dim_size(y_inner_dim),
+ errors::InvalidArgument(
+ "In[0] mismatch In[1] shape: ", x_shape.dim_size(x_inner_dim),
+ " vs. ", y_shape.dim_size(y_inner_dim), ": ", x_shape.DebugString(),
+ " ", y_shape.DebugString(), " ", adj_x_, " ", adj_y_));
+
+ int x_outer_dim = adj_x_ ? (ndims - 1) : (ndims - 2);
+ int y_outer_dim = adj_y_ ? (ndims - 2) : (ndims - 1);
+ dimensions.push_back(x_shape.dim_size(x_outer_dim));
+ dimensions.push_back(y_shape.dim_size(y_outer_dim));
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ xla::ComputationDataHandle x_handle = ctx->Input(0);
+ xla::ComputationDataHandle y_handle = ctx->Input(1);
+
+ // Reshape input tensors into 3D tensors by flattening the batch
+ // dimensions. This makes it easier to unroll the batch dimension.
+ auto x_flat =
+ builder->Reshape(x_handle, {batch_count, x_shape.dim_size(ndims - 2),
+ x_shape.dim_size(ndims - 1)});
+ auto y_flat =
+ builder->Reshape(y_handle, {batch_count, y_shape.dim_size(ndims - 2),
+ y_shape.dim_size(ndims - 1)});
+
+ // Slice batches into individual matrices and multiply them.
+ std::vector<xla::ComputationDataHandle> out_slices;
+ for (int i = 0; i < batch_count; ++i) {
+ // Slice off individual matrices and reshape to 2D tensors.
+ auto x_slice = builder->Slice(
+ x_flat, {i, 0, 0},
+ {i + 1, x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
+ x_slice = builder->Reshape(
+ x_slice, {x_shape.dim_size(ndims - 2), x_shape.dim_size(ndims - 1)});
+ auto y_slice = builder->Slice(
+ y_flat, {i, 0, 0},
+ {i + 1, y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
+ y_slice = builder->Reshape(
+ y_slice, {y_shape.dim_size(ndims - 2), y_shape.dim_size(ndims - 1)});
+
+ // Transpose if needed.
+ auto lhs = adj_x_ ? builder->Transpose(x_slice, {1, 0}) : x_slice;
+ auto rhs = adj_y_ ? builder->Transpose(y_slice, {1, 0}) : y_slice;
+
+ // Multiply matrices and add an outer singleton dimension to the output
+ // so we can concatenate along the flattened batch dimension later.
+ auto out = builder->Dot(lhs, rhs);
+ out = builder->Reshape(out,
+ {1, dimensions[ndims - 2], dimensions[ndims - 1]});
+ out_slices.push_back(out);
+ }
+
+ // Concatenate output slices and reshape to original number of dimensions.
+ xla::ComputationDataHandle data;
+ if (out_slices.empty()) {
+ // It is illegal to pass an empty list to ConcatInDim.
+ // The batch count is empty, so both inputs must have zero elements.
+ // Arbitrarily use the left input as the argument to Reshape().
+ data = x_handle;
+ } else {
+ data = builder->ConcatInDim(out_slices, 0);
+ }
+ data = builder->Reshape(data, dimensions);
+
+ ctx->SetOutput(0, data);
+ }
+
+ private:
+ bool adj_x_;
+ bool adj_y_;
+};
+
+REGISTER_XLA_OP("BatchMatMul", BatchMatMulOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
new file mode 100644
index 0000000000..f35835df08
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for broadcasting used in gradient code.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+namespace {
+
+// Given shapes of two tensors, computes the reduction indices for the
+// gradient computation.
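+//
+// For example, for input shapes [2, 3, 5] and [5] the op returns
+// grad_x_reduce_idx = [] and grad_y_reduce_idx = [0, 1]: the gradient w.r.t.
+// the second operand must be summed over the two broadcast dimensions.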
+//
+// TODO(zhifengc):
+// 1. Adds support for n-ary (n >= 2).
+class BCastGradArgsOp : public XlaOpKernel {
+ public:
+ explicit BCastGradArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(
+ ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() == 2,
+ errors::Unimplemented("Broadcast for n-ary operations (n > 2)"));
+
+ gtl::InlinedVector<BCast::Vec, 4> shapes;
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const TensorShape in_shape = ctx->InputShape(i);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape),
+ errors::InvalidArgument("In[", i, "] must be a vector.",
+ in_shape.DebugString()));
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(i, &literal));
+
+ BCast::Vec vec;
+ for (int64 i = 0; i < in_shape.num_elements(); ++i) {
+ vec.push_back(xla::LiteralUtil::Get<int>(literal, {i}));
+ }
+ shapes.push_back(vec);
+ }
+ BCast bcast(shapes[0], shapes[1]);
+ OP_REQUIRES(ctx, bcast.IsValid(),
+ errors::InvalidArgument(
+ "Incompatible shapes: [", str_util::Join(shapes[0], ","),
+ "] vs. [", str_util::Join(shapes[1], ","), "]"));
+ Output(ctx, 0, bcast.grad_x_reduce_idx());
+ Output(ctx, 1, bcast.grad_y_reduce_idx());
+ }
+
+ private:
+ void Output(XlaOpKernelContext* ctx, int idx, const BCast::Vec& v) {
+ const int64 len = v.size();
+ Tensor constant(DT_INT32, TensorShape({len}));
+ for (int64 i = 0; i < len; ++i) {
+ constant.flat<int32>()(i) = static_cast<int32>(v[i]);
+ }
+ ctx->SetConstantOutput(idx, constant);
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp);
+};
+
+REGISTER_XLA_OP("BroadcastGradientArgs", BCastGradArgsOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/bias_ops.cc b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
new file mode 100644
index 0000000000..217e82304e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/bias_ops.cc
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+class BiasOp : public XlaOpKernel {
+ public:
+ explicit BiasOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ if (ctx->GetAttr("data_format", &data_format).ok()) {
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ } else {
+ data_format_ = FORMAT_NHWC;
+ }
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape bias_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(input_shape),
+ errors::InvalidArgument("Input tensor must be at least 2D: ",
+ input_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_shape),
+ errors::InvalidArgument("Biases must be 1D: ",
+ bias_shape.DebugString()));
+ int feature_dim = (data_format_ == FORMAT_NHWC) ? input_shape.dims() - 1
+ : input_shape.dims() - 3;
+ OP_REQUIRES(
+ ctx, feature_dim >= 0,
+ errors::InvalidArgument("Input tensor does not have enough dimensions "
+ "to contain the feature dimension"));
+ OP_REQUIRES(
+ ctx, bias_shape.dim_size(0) == input_shape.dim_size(feature_dim),
+ errors::InvalidArgument(
+ "Must provide as many biases as the feature dimension "
+ "of the input tensor: ",
+ bias_shape.DebugString(), " vs. ", input_shape.DebugString()));
+
+ xla::ComputationDataHandle result =
+ ctx->builder()->Add(ctx->Input(0), ctx->Input(1), {feature_dim});
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("BiasAdd", BiasOp);
+REGISTER_XLA_OP("BiasAddV1", BiasOp);
+
+class BiasAddGradOp : public XlaOpKernel {
+ public:
+ explicit BiasAddGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ if (ctx->GetAttr("data_format", &data_format).ok()) {
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ } else {
+ data_format_ = FORMAT_NHWC;
+ }
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape out_backprop_shape = ctx->InputShape(0);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrixOrHigher(out_backprop_shape),
+ errors::InvalidArgument("Input tensor must be at least 2D: ",
+ out_backprop_shape.DebugString()));
+
+ int feature_dim = (data_format_ == FORMAT_NHWC)
+ ? out_backprop_shape.dims() - 1
+ : out_backprop_shape.dims() - 3;
+ OP_REQUIRES(
+ ctx, feature_dim >= 0,
+ errors::InvalidArgument("Input tensor does not have enough dimensions "
+ "to contain the feature dimension"));
+
+ std::vector<int64> reduce_dims(out_backprop_shape.dims() - 1);
+ std::iota(reduce_dims.begin(), reduce_dims.begin() + feature_dim, 0);
+ std::iota(reduce_dims.begin() + feature_dim, reduce_dims.end(),
+ feature_dim + 1);
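+ // For example, for a 4-D NHWC out_backprop, feature_dim == 3 and
+ // reduce_dims == {0, 1, 2}; for NCHW, feature_dim == 1 and
+ // reduce_dims == {0, 2, 3}.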
+ xla::ComputationDataHandle result = ctx->builder()->Reduce(
+ ctx->Input(0), XlaHelpers::Zero(ctx->builder(), input_type(0)),
+ *ctx->GetOrCreateAdd(input_type(0)), reduce_dims);
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("BiasAddGrad", BiasAddGradOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
new file mode 100644
index 0000000000..6f117ebe61
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -0,0 +1,158 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of simple binary Ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+// A subclass of XlaBinaryOp must build the computation that
+// describes the (tensor,tensor)->tensor function to apply to each element of
+// the input.
+#define XLA_MAKE_BINARY(Name, HLO) \
+ class Name##Op : public XlaBinaryOp { \
+ public: \
+ explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {} \
+ xla::ComputationDataHandle Computation( \
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs, \
+ const gtl::ArraySlice<int64>& lhs_shape, \
+ const xla::ComputationDataHandle& rhs, \
+ const gtl::ArraySlice<int64>& rhs_shape, \
+ const BCast& broadcast_helper, \
+ const std::vector<int64>& extend_dimensions) override { \
+ xla::ComputationBuilder* b = ctx->builder(); \
+ return HLO; \
+ } \
+ }; \
+ REGISTER_XLA_OP(#Name, Name##Op)
+
+XLA_MAKE_BINARY(Add, b->Add(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions));
+
+// Implementation of FloorDiv. Pseudo-code:
+// if ((x < 0) != (y < 0)) {
+// T abs_x = std::abs(x);
+// T abs_y = std::abs(y);
+// return -(abs_x + abs_y - 1) / abs_y;
+// } else {
+// return x / y;
+// }
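+//
+// For example, FloorDiv(-7, 2): the signs differ, so the result is
+// -(7 + 2 - 1) / 2 = -4, which matches floor(-7 / 2.0) = floor(-3.5).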
+static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b,
+ DataType dtype,
+ xla::ComputationDataHandle x,
+ xla::ComputationDataHandle y,
+ const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto one = XlaHelpers::One(b, dtype);
+ auto different_sign = b->Ne(b->Lt(x, zero), b->Lt(y, zero));
+ auto abs_x = b->Abs(x);
+ auto abs_y = b->Abs(y);
+ auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one));
+ auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y));
+ if (dtype == DT_FLOAT || dtype == DT_DOUBLE) {
+ result = b->Floor(result);
+ }
+ return result;
+}
+XLA_MAKE_BINARY(FloorDiv,
+ FloorDivImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+// Implementation of FloorMod. Pseudo-code:
+// T trunc_mod = std::fmod(x, y);
+// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
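+//
+// For example, FloorMod(-7, 2): trunc_mod = fmod(-7, 2) = -1 and the signs
+// differ, so the result is fmod(-1 + 2, 2) = 1 (Python-style modulus).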
+static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b,
+ DataType dtype,
+ xla::ComputationDataHandle x,
+ xla::ComputationDataHandle y,
+ const BCast& broadcast_helper) {
+ std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto same_sign = b->Eq(b->Lt(x, zero), b->Lt(y, zero));
+ auto trunc_mod = b->Rem(x, y);
+ return b->Select(same_sign, trunc_mod, b->Rem(b->Add(trunc_mod, y), y));
+}
+XLA_MAKE_BINARY(FloorMod,
+ FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper));
+
+XLA_MAKE_BINARY(LogicalAnd, b->LogicalAnd(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LogicalOr, b->LogicalOr(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(RealDiv, b->Div(lhs, rhs, extend_dimensions));
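+
+// RsqrtGrad(y, dy) computes -dy/2 * y^3, the gradient of rsqrt(x) expressed
+// in terms of its output y = rsqrt(x).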
+XLA_MAKE_BINARY(
+ RsqrtGrad,
+ b->Mul(b->Pow(lhs, XlaHelpers::IntegerLiteral(b, input_type(0), 3)),
+ b->Div(rhs, XlaHelpers::IntegerLiteral(b, input_type(0), -2)),
+ extend_dimensions));
+
+static xla::ComputationDataHandle Square(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& x) {
+ return builder->Mul(x, x);
+}
+
+XLA_MAKE_BINARY(SquaredDifference,
+ Square(b, b->Sub(lhs, rhs, extend_dimensions)));
+
+XLA_MAKE_BINARY(TruncateDiv, b->Div(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(TruncateMod, b->Rem(lhs, rhs, extend_dimensions));
+
+// Comparison ops
+XLA_MAKE_BINARY(Equal, b->Eq(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(NotEqual, b->Ne(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Greater, b->Gt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(GreaterEqual, b->Ge(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(Less, b->Lt(lhs, rhs, extend_dimensions));
+XLA_MAKE_BINARY(LessEqual, b->Le(lhs, rhs, extend_dimensions));
+
+#undef XLA_MAKE_BINARY
+
+#define XLA_MAKE_BINARY_MAP(Name, HLO) \
+ class Name##Op : public XlaBinaryMapOp { \
+ public: \
+ explicit Name##Op(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {} \
+ void BuildMapLambda(xla::ComputationBuilder* b, \
+ const xla::ComputationDataHandle& lhs, \
+ const xla::ComputationDataHandle& rhs) override { \
+ HLO; \
+ } \
+ }; \
+ REGISTER_XLA_OP(#Name, Name##Op)
+
+XLA_MAKE_BINARY_MAP(Pow, b->Pow(lhs, rhs));
+XLA_MAKE_BINARY_MAP(SigmoidGrad,
+ b->Mul(b->Mul(rhs, lhs),
+ b->Sub(XlaHelpers::One(b, input_type(0)), lhs)));
+XLA_MAKE_BINARY_MAP(SoftplusGrad,
+ b->Div(lhs, b->Add(b->Exp(b->Neg(rhs)),
+ XlaHelpers::One(b, input_type(1)))));
+XLA_MAKE_BINARY_MAP(TanhGrad,
+ b->Mul(rhs, b->Sub(XlaHelpers::One(b, input_type(0)),
+ b->Mul(lhs, lhs))));
+
+#undef XLA_MAKE_BINARY_MAP
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
new file mode 100644
index 0000000000..b0188b4f8d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+class CastOp : public XlaOpKernel {
+ public:
+ explicit CastOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("SrcT", &src_dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("DstT", &dst_dtype_));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(src_dtype_, &src_type_));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dst_dtype_, &dst_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+ xla::ComputationDataHandle input = ctx->Input(0);
+ xla::ComputationDataHandle output;
+
+ if (src_dtype_ == dst_dtype_) {
+ output = input;
+ } else if (src_dtype_ == DT_BOOL) {
+ // XLA's ConvertElementType doesn't support casting to/from
+ // bools. So we need to handle those cases separately.
+ // Builds the equivalent of (input ? 1 : 0)
+ xla::ComputationBuilder l(builder->client(), "PredCast");
+ xla::ComputationDataHandle x =
+ l.Parameter(0, xla::ShapeUtil::MakeShape(src_type_, {}), "x");
+ l.Select(x, XlaHelpers::One(&l, dst_dtype_),
+ XlaHelpers::Zero(&l, dst_dtype_));
+ xla::Computation computation = l.Build().ConsumeValueOrDie();
+ output = builder->Map({input}, computation);
+ } else if (dst_dtype_ == DT_BOOL) {
+ output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_));
+ } else {
+ output = builder->ConvertElementType(input, dst_type_);
+ }
+
+ ctx->SetOutput(0, output);
+ }
+
+ protected:
+ DataType src_dtype_, dst_dtype_;
+ xla::PrimitiveType src_type_, dst_type_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CastOp);
+};
+
+REGISTER_XLA_OP("Cast", CastOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
new file mode 100644
index 0000000000..96ef2ac20c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Concat Ops.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// --------------------------------------------------------------------------
+class ConcatBaseOp : public XlaOpKernel {
+ public:
+ ConcatBaseOp(OpKernelConstruction* c, int axis_index)
+ : XlaOpKernel(c), axis_index_(axis_index) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_);
+ OP_REQUIRES(
+ ctx, IsLegacyScalar(concat_dim_tensor_shape),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_tensor_shape.DebugString()));
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal));
+ // TODO(annarev): add a helper to support int64 input.
+ const int32 concat_dim = xla::LiteralUtil::Get<int>(literal, {});
+
+ std::vector<xla::ComputationDataHandle> values;
+ std::vector<TensorShape> shapes;
+ OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
+ const int N = values.size();
+ const int input_dims = shapes[0].dims();
+ const TensorShape& input_shape = shapes[0];
+
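+ // Normalize a negative concat dimension; e.g. concat_dim = -1 on
+ // rank-3 inputs resolves to axis 2.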
+ int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
+ OP_REQUIRES(ctx,
+ (0 <= axis && axis < input_dims) ||
+ (allow_legacy_scalars() && concat_dim == 0),
+ errors::InvalidArgument(
+ "ConcatOp : Expected concatenating dimensions in the range "
+ "[",
+ -input_dims, ", ", input_dims, "), but got ", concat_dim));
+
+ // Make a vector holding the ComputationDataHandles for each of
+ // the inputs that has non-zero elements.
+ std::vector<xla::ComputationDataHandle> input_data;
+ int output_concat_dim = 0;
+ const bool input_is_scalar = IsLegacyScalar(input_shape);
+ for (int i = 0; i < N; ++i) {
+ xla::ComputationDataHandle handle = values[i];
+ const TensorShape& in_shape = shapes[i];
+ const bool in_is_scalar = IsLegacyScalar(in_shape);
+ OP_REQUIRES(
+ ctx,
+ in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar),
+ errors::InvalidArgument(
+ "ConcatOp : Ranks of all input tensors should match: shape[0] = ",
+ input_shape.DebugString(), " vs. shape[", i, "] = ",
+ in_shape.DebugString()));
+ if (in_shape.dims() == 0) {
+ // Inputs that come in as scalars must be reshaped to 1-vectors.
+ input_data.push_back(ctx->builder()->Reshape(handle, {1}));
+ } else {
+ input_data.push_back(handle);
+ }
+ output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1;
+ }
+
+ VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis;
+ ctx->SetOutput(0, ctx->builder()->ConcatInDim(input_data, axis));
+ }
+
+ private:
+ int axis_index_;
+};
+
+class ConcatOp : public ConcatBaseOp {
+ public:
+ explicit ConcatOp(OpKernelConstruction* c)
+ : ConcatBaseOp(c, /* axis_index */ 0) {}
+};
+
+// ConcatV2 operation is the same as Concat except 'concat_dim'
+// is the last input instead of the first and renamed to 'axis'.
+class ConcatV2Op : public ConcatBaseOp {
+ public:
+ explicit ConcatV2Op(OpKernelConstruction* c)
+ : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {}
+};
+
+REGISTER_XLA_OP("Concat", ConcatOp);
+REGISTER_XLA_OP("ConcatV2", ConcatV2Op);
+
+class ConcatOffsetOp : public XlaOpKernel {
+ public:
+ explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape concat_dim_shape = ctx->InputShape(0);
+ OP_REQUIRES(
+ ctx, IsLegacyScalar(concat_dim_shape),
+ errors::InvalidArgument(
+ "Concat dim tensor should be a scalar integer, but got shape ",
+ concat_dim_shape.DebugString()));
+ for (int i = 1; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)),
+ errors::InvalidArgument("input ", i,
+ " should be a vector, but got shape ",
+ ctx->InputShape(i).DebugString()));
+ }
+ // Suppose a Concat() op needs to concatenate N tensors, each of
+ // which has the same number of dimensions. Their shapes match
+ // except the concat dimension.
+ //
+ // E.g., say, we want to concatenate 3 tensors in the 2nd
+ // dimension, and their shapes are:
+ //
+ // [2, 2, 5, 7]
+ // [2, 3, 5, 7]
+ // [2, 4, 5, 7]
+ //
+ // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape
+ // [2,9,5,7]. We will compute the cumulative sum along the 2nd
+ // dimension to figure out each input's offset in the concatenated
+ // output:
+ // [0, 0, 0, 0]
+ // [0, 2, 0, 0]
+ // [0, 5, 0, 0]
+ const int32 N = ctx->num_inputs() - 1;
+ const TensorShape inp0_shape = ctx->InputShape(1);
+ xla::Literal inp0_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &inp0_literal));
+ const int64 dims = inp0_shape.num_elements();
+
+ xla::Literal concat_dim_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal));
+ const int64 cdim = xla::LiteralUtil::Get<int>(concat_dim_literal, {});
+
+ VLOG(1) << "ConcatOffset " << cdim << "," << dims;
+ int32 axis = cdim < 0 ? cdim + dims : cdim;
+ OP_REQUIRES(ctx, FastBoundsCheck(axis, dims),
+ errors::InvalidArgument("Concat dim is out of range: ", axis,
+ " vs. ", dims));
+ int32 offset = 0;
+ for (int i = 0; i < N; ++i) {
+ const TensorShape inp_shape = ctx->InputShape(1 + i);
+ OP_REQUIRES(ctx, dims == inp_shape.num_elements(),
+ errors::InvalidArgument("input ", i, " should contain ", dims,
+ " elements, but got",
+ inp_shape.num_elements()));
+ xla::Literal inp_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1 + i, &inp_literal));
+
+ Tensor out_constant(DT_INT32, TensorShape({dims}));
+ auto out_vec = out_constant.vec<int32>();
+ for (int64 j = 0; j < dims; ++j) {
+ if (j == axis) {
+ out_vec(j) = offset;
+ offset += xla::LiteralUtil::Get<int>(inp_literal, {j});
+ } else {
+ const int32 inp0_element =
+ xla::LiteralUtil::Get<int>(inp0_literal, {j});
+ const int32 inp_element =
+ xla::LiteralUtil::Get<int>(inp_literal, {j});
+ OP_REQUIRES(
+ ctx, (inp0_element == inp_element),
+ errors::InvalidArgument("input[", i, ",", j, "] mismatch: ",
+ inp0_element, " vs. ", inp_element));
+ out_vec(j) = 0;
+ }
+ }
+
+ ctx->SetConstantOutput(i, out_constant);
+ }
+ }
+};
+
+REGISTER_XLA_OP("ConcatOffset", ConcatOffsetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
new file mode 100644
index 0000000000..9bebfcfe47
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -0,0 +1,373 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D convolution.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+
+class Conv2DOp : public XlaOpKernel {
+ public:
+ explicit Conv2DOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(ctx, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
+ const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ ctx, stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, out_depth]
+ const TensorShape filter_shape = ctx->InputShape(1);
+
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(ctx, input_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_shape.DebugString()));
+ OP_REQUIRES(ctx, filter_shape.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter_shape.DebugString()));
+
+ // The 'C' dimension for input is in_depth. It must be the same as
+ // the filter's in_depth.
+ const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C');
+ OP_REQUIRES(
+ ctx, in_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ in_depth, " vs ", filter_shape.dim_size(2)));
+
+ // The last dimension for filter is out_depth.
+ const int64 out_depth = filter_shape.dim_size(3);
+
+ // The 'H' dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ const int64 filter_rows = filter_shape.dim_size(0);
+
+ // The 'W' dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ const int64 filter_cols = filter_shape.dim_size(1);
+
+ // For now we take the stride from the H and W dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(ctx,
+ GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(ctx,
+ GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ VLOG(2) << "Conv2D: in_depth = " << in_depth
+ << ", input_cols = " << input_cols
+ << ", filter_cols = " << filter_cols
+ << ", input_rows = " << input_rows
+ << ", filter_rows = " << filter_rows
+ << ", stride_rows = " << stride_rows
+ << ", stride_cols = " << stride_cols
+ << ", out_depth = " << out_depth;
+
+ xla::ConvolutionDimensionNumbers dims;
+ dims.set_batch_dimension(GetTensorDimIndex<2>(data_format_, 'N'));
+ dims.set_feature_dimension(GetTensorDimIndex<2>(data_format_, 'C'));
+ dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'H'));
+ dims.add_spatial_dimensions(GetTensorDimIndex<2>(data_format_, 'W'));
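+ // For NHWC, the dimension numbers above resolve to batch = 0,
+ // spatial = {1, 2}, feature = 3; for NCHW, to batch = 0, feature = 1,
+ // spatial = {2, 3}.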
+
+ // TF filter shape is [ H, W, inC, outC ]
+ dims.add_kernel_spatial_dimensions(0);
+ dims.add_kernel_spatial_dimensions(1);
+ dims.set_kernel_input_feature_dimension(2);
+ dims.set_kernel_output_feature_dimension(3);
+
+ std::vector<int64> window_strides = {stride_rows, stride_cols};
+ xla::Padding xla_padding =
+ (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ xla::ComputationDataHandle conv = ctx->builder()->ConvWithGeneralDimensions(
+ ctx->Input(0), ctx->Input(1), window_strides, xla_padding, dims);
+ ctx->SetOutput(0, conv);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
+};
+
+REGISTER_XLA_OP("Conv2D", Conv2DOp);
+
+// Backprop for input.
+class Conv2DBackpropInputOp : public XlaOpKernel {
+ public:
+ explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
+ OP_REQUIRES(ctx, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ ctx, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape input_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
+
+ const TensorShape filter_shape = ctx->InputShape(1);
+ const TensorShape out_backprop_shape = ctx->InputShape(2);
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ Conv2DBackpropDimensions dims;
+ OP_REQUIRES_OK(
+ ctx, Conv2DBackpropComputeDimensions(
+ "Conv2DBackpropInput", input_shape, filter_shape,
+ out_backprop_shape, strides_, padding_, data_format_, &dims));
+
+ auto filter = ctx->Input(1);
+ auto out_backprop = ctx->Input(2);
+
+ // The input gradients are computed by a convolution of the output
+ // gradients and the filter, with some appropriate padding. See the
+ // comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(GetTensorDimIndex(data_format_, 'N'));
+ dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'H'));
+ dnums.add_spatial_dimensions(GetTensorDimIndex(data_format_, 'W'));
+ dnums.set_feature_dimension(GetTensorDimIndex(data_format_, 'C'));
+
+ // TF filter shape is [ H, W, inC, outC ]
+ // Transpose the input and output features for computing the gradient.
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(3);
+ dnums.set_kernel_output_feature_dimension(2);
+
+ // Mirror the filter in the spatial dimensions.
+ xla::ComputationDataHandle mirrored_weights =
+ ctx->builder()->Rev(filter, {dnums.kernel_spatial_dimensions(0),
+ dnums.kernel_spatial_dimensions(1)});
+
+ // activation gradients
+ // = gradients (with padding and dilation) <conv> mirrored_weights
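+ // The lhs_dilation of {stride_rows, stride_cols} inserts stride - 1
+ // zeros between adjacent output-gradient elements, which is what makes
+ // this the transposed form of the forward convolution.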
+ xla::ComputationDataHandle in_backprop = ctx->builder()->ConvGeneralDilated(
+ out_backprop, mirrored_weights, /*window_strides=*/{1, 1},
+ /*padding=*/{{dims.rows.pad_before, dims.rows.pad_after},
+ {dims.cols.pad_before, dims.cols.pad_after}},
+ /*lhs_dilation=*/{dims.rows.stride, dims.cols.stride},
+ /*rhs_dilation=*/{1, 1}, dnums);
+
+ ctx->SetOutput(0, in_backprop);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropInputOp);
+};
+
+class Conv2DBackpropFilterOp : public XlaOpKernel {
+ public:
+ explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
+ int stride_n = GetTensorDim(strides_, data_format_, 'N');
+ int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ OP_REQUIRES(
+ ctx, (stride_n == 1 && stride_c == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape activations_shape = ctx->InputShape(0);
+ TensorShape filter_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
+ const TensorShape out_backprop_shape = ctx->InputShape(2);
+
+ // Reuse dimension computation logic from conv_grad_ops.cc.
+ Conv2DBackpropDimensions dims;
+ OP_REQUIRES_OK(
+ ctx, Conv2DBackpropComputeDimensions(
+ "Conv2DBackpropFilter", activations_shape, filter_shape,
+ out_backprop_shape, strides_, padding_, data_format_, &dims));
+
+ xla::ComputationDataHandle activations = ctx->Input(0);
+ xla::ComputationDataHandle gradients = ctx->Input(2);
+
+ // The filter gradients are computed by a convolution of the input
+ // activations and the output gradients, with some appropriate padding.
+ // See the comment at the top of conv_grad_ops.h for details.
+
+ xla::ConvolutionDimensionNumbers dnums;
+
+ // The activations (inputs) form the LHS of the convolution.
+ // Activations have shape: [batch, in_rows, in_cols, in_depth]
+ // For the gradient computation, we flip the roles of the batch and
+ // feature dimensions.
+ // Each spatial entry has size in_depth * batch
+ const int n_dim = GetTensorDimIndex(data_format_, 'N');
+ const int h_dim = GetTensorDimIndex(data_format_, 'H');
+ const int w_dim = GetTensorDimIndex(data_format_, 'W');
+ const int c_dim = GetTensorDimIndex(data_format_, 'C');
+
+ // Swap n_dim and c_dim in the activations.
+ dnums.set_batch_dimension(c_dim);
+ dnums.add_spatial_dimensions(h_dim);
+ dnums.add_spatial_dimensions(w_dim);
+ dnums.set_feature_dimension(n_dim);
+
+ // The gradients become the RHS of the convolution.
+ // The gradients have shape [batch, out_rows, out_cols, out_depth] where
+ // the batch becomes the input feature for the convolution.
+ dnums.add_kernel_spatial_dimensions(h_dim);
+ dnums.add_kernel_spatial_dimensions(w_dim);
+ dnums.set_kernel_input_feature_dimension(n_dim);
+ dnums.set_kernel_output_feature_dimension(c_dim);
+
+ // We will also need to pad the input with zeros such that after the
+ // convolution, we get the right size for the filter.
+ // The padded_in_rows should be such that when we convolve this with the
+ // expanded_out_rows as a filter, we should get filter_rows back.
+ //
+ const int padded_in_rows =
+ dims.rows.expanded_output_size + dims.rows.filter_size - 1;
+ const int padded_in_cols =
+ dims.cols.expanded_output_size + dims.cols.filter_size - 1;
+
+ // However, it can be smaller than input_rows: in that case some of
+ // the inputs are not used.
+ //
+ // An example is to have input_cols = 3, filter_cols = 2 and stride = 2:
+ //
+ // INPUT = [ A B C ]
+ //
+ // FILTER = [ x y ]
+ //
+ // and the output will only have one column: a = A * x + B * y
+ //
+ // and input "C" is not used at all.
+ //
+ // We apply negative padding in this case.
+ const int total_pad_in_rows = padded_in_rows - dims.rows.input_size;
+ const int total_pad_in_cols = padded_in_cols - dims.cols.input_size;
+
+ // + For the VALID padding, we don't pad anything on the top/left side
+ // and pad the bottom/right side with the remaining space.
+ // + For the SAME padding, we pad top/left side the same as bottom/right
+ // side.
+ //
+ // In addition, if the padded input size is smaller than the input size,
+ // we need to ignore some trailing elements of the input. We do this by
+ // applying negative padding on the right/bottom.
+ const int top_pad_in_rows =
+ (total_pad_in_rows > 0 && padding_ == Padding::SAME)
+ ? total_pad_in_rows / 2
+ : 0;
+ const int left_pad_in_cols =
+ (total_pad_in_cols > 0 && padding_ == Padding::SAME)
+ ? total_pad_in_cols / 2
+ : 0;
+
+ // Besides padding the input, we will also expand output_rows to
+ // expanded_out_rows = (output_rows - 1) * stride + 1
+ // with zeros in between:
+ //
+ // a . . . b . . . c . . . d . . . e
+ //
+ // This is done by specifying the window dilation factors in the
+ // convolution HLO below.
+ auto filter_backprop = ctx->builder()->ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{top_pad_in_rows, total_pad_in_rows - top_pad_in_rows},
+ {left_pad_in_cols, total_pad_in_cols - left_pad_in_cols}},
+ /*lhs_dilation=*/{1, 1},
+ /*rhs_dilation=*/{dims.rows.stride, dims.cols.stride}, dnums);
+
+ // The layout of filter_backprop will match the layout of the padded
+ // activations, i.e. [out_feature, h, w, in_feature]. The TensorFlow
+ // filter shape is [ H, W, inC, outC ], so we transpose the output.
+ xla::ComputationDataHandle filter_backprop_reshaped =
+ ctx->builder()->Transpose(filter_backprop,
+ {h_dim, w_dim, c_dim, n_dim});
+ ctx->SetOutput(0, filter_backprop_reshaped);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ TensorFormat data_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv2DBackpropFilterOp);
+};
+
+REGISTER_XLA_OP("Conv2DBackpropInput", Conv2DBackpropInputOp);
+REGISTER_XLA_OP("Conv2DBackpropFilter", Conv2DBackpropFilterOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
new file mode 100644
index 0000000000..3cd0b39c87
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -0,0 +1,177 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific base classes for Unary and Binary Ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
+ const TensorShape lhs_shape = ctx->InputShape(0);
+ const TensorShape rhs_shape = ctx->InputShape(1);
+
+ // By TensorFlow conventions the inputs may not have the same
+ // shapes, in which case they will be automatically broadcast if
+ // possible before mapping. Use the standard TensorFlow helper to
+ // compute valid broadcast shapes, but rely below on XLA to
+ // automatically perform the broadcast assuming its valid shapes are
+ // a superset of TensorFlow's valid shapes.
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ if (!bcast.IsValid()) {
+ ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
+ lhs_shape.DebugString(), " vs. ",
+ rhs_shape.DebugString()));
+ return;
+ }
+ TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
+
+ // Fetch the expressions containing the input tensors.
+ auto lhs_handle = ctx->Input(0);
+ auto rhs_handle = ctx->Input(1);
+
+ // If the ranks of the inputs don't match, TensorFlow automatically
+ // reshapes the smaller by padding with dimensions of size 1 as a
+ // prefix. In other words, to pad a 5-vector to a 3-dimensional
+ // tensor, it is reshaped to have shape [1,1,5]. XLA's automatic
+ // broadcast code is able to broadcast from lower to higher rank,
+ // but doesn't assume you want to pad as a prefix of the dimensions,
+ // and instead needs to be told which dimensions of the higher rank
+ // tensor to match to the lower rank tensor. In this example it
+ // would be dimensions [2]. If we were matching a matrix against a
+ // 4-D tensor, the dimensions to match would be [2,3], and so on;
+ // extend_dimension encodes the general case.
+ std::vector<int64> extend_dimension;
+ int max_rank = std::max(lhs_shape.dims(), rhs_shape.dims());
+ int min_rank = std::min(lhs_shape.dims(), rhs_shape.dims());
+ if (min_rank != max_rank) {
+ for (int i = 0; i < min_rank; ++i) {
+ // Match the lower rank tensor along the larger-numbered
+ // dimensions of the higher rank tensor.
+ extend_dimension.push_back(max_rank - min_rank + i);
+ }
+ }
+
+ // Call virtual method to emit the computation.
+ xla::ComputationDataHandle output =
+ Computation(ctx, lhs_handle, lhs_shape.dim_sizes(), rhs_handle,
+ rhs_shape.dim_sizes(), bcast, extend_dimension);
+
+ // The TensorFlow helper computed the post-broadcast shape in
+ // output_shape: we rely on subclassed Computations to implement the
+ // same broadcast semantics.
+ ctx->SetOutput(0, output);
+}
+
+/* static */ std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
+XlaBinaryOp::Broadcast(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& lhs,
+ const xla::ComputationDataHandle& rhs,
+ const BCast& broadcast_helper) {
+ // Manually construct the broadcasting since MapN does not do
+ // automatic broadcasting. The bcast helper ensures that
+ // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
+ // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
+ // the same shape, so they can be operated on by MapN.
+
+ // First reshape the inputs, which should be a metadata-only
+ // operation since we are flattening the dimensions in order.
+ auto lhs_shaped = builder->Reshape(lhs, broadcast_helper.x_reshape());
+ auto rhs_shaped = builder->Reshape(rhs, broadcast_helper.y_reshape());
+
+ // Next broadcast the necessary input dimensions. We rely on the
+ // XLA optimizer to be smart about the fact that we are asking
+ // it to broadcast size 1 on some of these dimensions, to avoid
+ // adding complexity to this code.
+ auto lhs_broadcast =
+ builder->Broadcast(lhs_shaped, broadcast_helper.x_bcast());
+ int lhs_size = broadcast_helper.x_bcast().size();
+ auto rhs_broadcast =
+ builder->Broadcast(rhs_shaped, broadcast_helper.y_bcast());
+ int rhs_size = broadcast_helper.y_bcast().size();
+
+ // Now reshape them to the correct output shape. After the broadcast
+ // each side has twice as many dimensions as it should, since the
+ // broadcast dimensions were prepended to the shape. Reshape,
+ // flattening each original dimension together with its prepended
+ // broadcast dimension. E.g. if we started out with lhs_shaped with shape
+ // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
+ // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
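+ // Continuing that example, lhs_reorder below is {0,3,1,4,2,5}: each
+ // original dimension is paired with its prepended broadcast dimension,
+ // and the pairs collapse as (2,5)->10, (1,2)->2 and (7,3)->21.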
+ std::vector<int64> lhs_reorder;
+ for (int i = 0; i < lhs_size; ++i) {
+ lhs_reorder.push_back(i);
+ lhs_reorder.push_back(i + lhs_size);
+ }
+ auto lhs_output = builder->Reshape(lhs_broadcast, lhs_reorder,
+ broadcast_helper.output_shape());
+ std::vector<int64> rhs_reorder;
+ for (int i = 0; i < rhs_size; ++i) {
+ rhs_reorder.push_back(i);
+ rhs_reorder.push_back(i + rhs_size);
+ }
+ auto rhs_output = builder->Reshape(rhs_broadcast, rhs_reorder,
+ broadcast_helper.output_shape());
+
+ return {lhs_output, rhs_output};
+}
+
+xla::ComputationDataHandle XlaBinaryMapOp::Computation(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
+ const gtl::ArraySlice<int64>& lhs_shape,
+ const xla::ComputationDataHandle& rhs,
+ const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+ const std::vector<int64>& extend_dimensions) {
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ // Construct the builder for the lambda computation.
+ xla::ComputationBuilder l(builder->client(), ctx->op_kernel().name());
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
+
+ // Make two scalar parameters of the desired type for the lambda.
+ xla::ComputationDataHandle x =
+ l.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ xla::ComputationDataHandle y =
+ l.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
+
+ // Call virtual method to build the lambda.
+ BuildMapLambda(&l, x, y);
+ xla::Computation computation = l.Build().ConsumeValueOrDie();
+
+ xla::ComputationDataHandle lhs_broadcast = lhs;
+ xla::ComputationDataHandle rhs_broadcast = rhs;
+ if (lhs_shape == rhs_shape) {
+ // There's no broadcasting to do.
+ CHECK_EQ(0, extend_dimensions.size());
+ return builder->Map({lhs, rhs}, computation);
+ } else {
+ std::tie(lhs_broadcast, rhs_broadcast) =
+ Broadcast(builder, lhs, rhs, broadcast_helper);
+ }
+ // Now that the two sides are broadcast to the final shape, we can do the map.
+ return builder->Map({lhs_broadcast, rhs_broadcast}, computation);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
new file mode 100644
index 0000000000..f0687c1d4b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific base classes for Unary and Binary Ops.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/bcast.h"
+
+namespace tensorflow {
+
+// Coefficient-wise binary operations. Each binary Op expects two
+// inputs that can be broadcast to the same shape. The base class
+// contains a pure virtual method to override: Computation adds the
+// implementation of the operation to an xla::ComputationBuilder. For most
+// arithmetic Ops XLA handles the broadcasting automatically given the input
+// tensors. Ops like ReluGrad that need to map a scalar function over the inputs
+// can use the XlaBinaryMapOp subclass below which handles manual
+// broadcasting of the inputs.
+class XlaBinaryOp : public XlaOpKernel {
+ public:
+ explicit XlaBinaryOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ const DataType lhs = BaseType(input_type(0));
+ const DataType rhs = BaseType(input_type(1));
+ OP_REQUIRES(ctx, lhs == rhs,
+ errors::InvalidArgument("Input types of binary op must match"));
+ }
+ ~XlaBinaryOp() override {}
+
+ // Implement the (tensor,tensor)->tensor computation that should be
+ // applied to the inputs. The desired computation should be added to
+ // 'ctx->builder()'; '(lhs,rhs)' are the function's inputs and
+ // '(lhs_shape,rhs_shape)' are their respective shapes.
+ // 'broadcast_helper' contains metadata about the shapes of the
+ // inputs and the dimensions that need to be broadcast, which may be
+ // useful for Ops that can't use standard XLA automatic broadcasting.
+ // 'extend_dimensions' is non-empty if lhs and rhs have different
+ // ranks, and indicates which dimensions of the higher-rank input
+ // should be matched when broadcasting the lower-rank input. See the
+ // comment below and the XLA broadcasting documentation.
+ virtual xla::ComputationDataHandle Computation(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
+ const gtl::ArraySlice<int64>& lhs_shape,
+ const xla::ComputationDataHandle& rhs,
+ const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+ const std::vector<int64>& extend_dimensions) = 0;
+
+ void Compile(XlaOpKernelContext* ctx) override;
+
+ // Helper function that performs the broadcasting described by
+ // 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
+ // shape.
+ static std::pair<xla::ComputationDataHandle, xla::ComputationDataHandle>
+ Broadcast(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& lhs,
+ const xla::ComputationDataHandle& rhs,
+ const BCast& broadcast_helper);
+};
+
+// Coefficient-wise binary operations that map a scalar function. Each
+// BinaryMap Op expects two inputs that can be broadcast to the same
+// shape and maps a (scalar,scalar)->scalar function across the zipped
+// elements of its (broadcast) inputs. The base class contains a pure
+// virtual method to override: BuildMapLambda adds the implementation of
+// the scalar lambda to an xla::ComputationBuilder.
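+//
+// As a purely hypothetical illustration (the real kernels are generated by
+// the XLA_MAKE_BINARY_MAP macro elsewhere in this change), a subclass that
+// adds its two inputs elementwise could look like:
+//
+//   class AddMapOp : public XlaBinaryMapOp {
+//    public:
+//     explicit AddMapOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+//     void BuildMapLambda(xla::ComputationBuilder* b,
+//                         const xla::ComputationDataHandle& x,
+//                         const xla::ComputationDataHandle& y) override {
+//       b->Add(x, y);
+//     }
+//   };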
+class XlaBinaryMapOp : public XlaBinaryOp {
+ public:
+ explicit XlaBinaryMapOp(OpKernelConstruction* ctx) : XlaBinaryOp(ctx) {}
+ ~XlaBinaryMapOp() override {}
+
+ // Implement the (scalar,scalar)->scalar lambda that should be
+ // applied to each pair of elements of the inputs. The desired
+ // computation should be added to 'builder' and
+ // '(scalar_lhs,scalar_rhs)' are the function's inputs.
+ virtual void BuildMapLambda(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) = 0;
+
+ xla::ComputationDataHandle Computation(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& lhs,
+ const gtl::ArraySlice<int64>& lhs_shape,
+ const xla::ComputationDataHandle& rhs,
+ const gtl::ArraySlice<int64>& rhs_shape, const BCast& broadcast_helper,
+ const std::vector<int64>& extend_dimensions) override;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_CWISE_OPS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/declaration_op.cc b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc
new file mode 100644
index 0000000000..d96ff34178
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/declaration_op.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+// This OpKernel implements the Constant Op for XLA JIT
+// devices. It extracts the constant Tensor from the Proto at kernel
+// construction time, and then every time the Constant Op is executed
+// an expression containing the constant is compiled.
+class ConstantDeclarationOp : public XlaOpKernel {
+ public:
+ explicit ConstantDeclarationOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx), tensor_(ctx->output_type(0)) {
+ const TensorProto* proto = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+ // MakeTensorFromProto uses the cpu_allocator, so tensor_ is a
+ // "real" tensor backed by CPU memory, holding the value of the
+ // constant.
+ OP_REQUIRES_OK(ctx, MakeTensorFromProto(*proto, &tensor_));
+ OP_REQUIRES(
+ ctx, ctx->output_type(0) == tensor_.dtype(),
+ errors::InvalidArgument(
+ "Type mismatch between value (", DataTypeString(tensor_.dtype()),
+ ") and dtype (", DataTypeString(ctx->output_type(0)), ")"));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->SetConstantOutput(0, tensor_);
+ }
+
+ private:
+ // Extract the value of the constant from the Proto during Op kernel
+ // construction. The constant must be stored in a Tensor allocated
+ // using the cpu_allocator so that it is backed by real memory. The
+ // OpKernelConstruction's default allocator is the JITAllocator
+ // which only allocates enough space for metadata for each Tensor.
+ static Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ Tensor* tensor) {
+ Tensor parsed(tensor_proto.dtype());
+ if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
+ return errors::InvalidArgument("Cannot parse tensor from proto: ",
+ tensor_proto.DebugString());
+ }
+ *tensor = parsed;
+ return Status::OK();
+ }
+
+ // This is a "real" tensor backed by CPU memory, containing the
+ // constant values.
+ Tensor tensor_;
+ TF_DISALLOW_COPY_AND_ASSIGN(ConstantDeclarationOp);
+};
+
+REGISTER_XLA_OP("Const", ConstantDeclarationOp);
+
+// This OpKernel implements the _Arg Op for XLA JIT devices. It
+// associates its output with one of the arguments to a
+// subcomputation.
+class ArgOp : public XlaOpKernel {
+ public:
+ explicit ArgOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype_, &type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // If 'frame' is non-null, this is a function call inside an outer JIT
+ // compilation. Use the usual implementation of _Arg.
+ auto frame = ctx->call_frame();
+ if (frame != nullptr) {
+ Tensor val;
+ OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
+ OP_REQUIRES(ctx, val.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(val.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ // Forwards the argument from the frame.
+ ctx->op_kernel_context()->set_output(0, val);
+ return;
+ }
+
+ XlaContext& tc = XlaContext::Get(ctx);
+
+ OP_REQUIRES(ctx, 0 <= index_ && index_ < tc.args().size(),
+ errors::InvalidArgument("Invalid argument index ", index_));
+ const XlaCompiler::Argument& arg = tc.args()[index_];
+
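+ // Otherwise this _Arg belongs to the XLA computation being built: a
+ // parameter index < 0 means the argument has a known constant value,
+ // so we emit that constant instead of a computation parameter.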
+ if (arg.parameter < 0) {
+ ctx->SetConstantOutput(0, arg.constant_value);
+ } else {
+ ctx->SetOutput(0, tc.parameter(arg.parameter));
+ }
+ }
+
+ private:
+ int index_;
+ DataType dtype_;
+ xla::PrimitiveType type_; // Corresponding XLA type.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgOp);
+};
+
+REGISTER_XLA_OP("_Arg", ArgOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
new file mode 100644
index 0000000000..d408ab3338
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/depthwise_conv_ops.cc
@@ -0,0 +1,235 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for 2D depthwise convolution.
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/depthwise_conv_op.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+// Name of the function to use as the implementation for depthwise 2D
+// convolution. Default is empty string; another possible value is
+// "DummyDepthwiseConv2dKernel".
+static const char kDepthwiseConv2dCustomFunc[] = "";
+
+class DepthwiseConv2dNativeOp : public XlaOpKernel {
+ public:
+ explicit DepthwiseConv2dNativeOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ // TODO(keveman): Refactor this (and other XLA OpKernel constructors) so
+ // that they use a common implementation shared with non-XLA kernels.
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
+ OP_REQUIRES(ctx, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(ctx, strides_[1] == strides_[2],
+ errors::InvalidArgument(
+ "Current implementation only supports equal length "
+ "strides in the row and column dimensions."));
+ OP_REQUIRES(
+ ctx, (strides_[0] == 1 && strides_[3] == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_rows, in_cols, in_depth ]
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, depth_multiplier]
+ const TensorShape filter_shape = ctx->InputShape(1);
+
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(ctx, input_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_shape.DebugString()));
+ OP_REQUIRES(ctx, filter_shape.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter_shape.DebugString()));
+
+ // The last dimension for input is in_depth. It must be the same as the
+ // filter's in_depth.
+ const int64 in_depth = input_shape.dim_size(3);
+ OP_REQUIRES(
+ ctx, in_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ in_depth, " vs ", filter_shape.dim_size(2)));
+
+ // The last dimension for filter is depth multiplier.
+ const int64 depth_multiplier = filter_shape.dim_size(3);
+
+ // The output depth is input depth x depth multiplier.
+ const int64 out_depth = in_depth * depth_multiplier;
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows = input_shape.dim_size(1);
+ const int64 filter_rows = filter_shape.dim_size(0);
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols = input_shape.dim_size(2);
+ const int64 filter_cols = filter_shape.dim_size(1);
+
+ // The first dimension for input is batch.
+ const int64 batch = input_shape.dim_size(0);
+
+ // For now we take the stride from the second dimension only (we
+ // assume row = col stride, and do not support striding on the
+ // batch or depth dimension).
+ const int32 stride = strides_[1];
+
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_rows, filter_rows, stride,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_cols, filter_cols, stride,
+ padding_, &out_cols, &pad_cols));
+ TensorShape out_shape({batch, out_rows, out_cols, out_depth});
+ OP_REQUIRES(
+ ctx, out_shape.num_elements() <= 2147483647,
+ errors::InvalidArgument("total number of outputs, ",
+ out_shape.num_elements(),
+ ", should be within the range of int32"));
+
+ // Output tensor is of the following dimensions:
+ // [ in_batch, out_rows, out_cols, out_depth ]
+
+ VLOG(2) << "DepthwiseConv2dNative: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; stride = " << stride << ", pad_rows = " << pad_rows
+ << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
+ << out_rows << ", " << out_cols << ", " << out_depth << "]";
+
+ xla::ComputationBuilder& b = *ctx->builder();
+ xla::ComputationDataHandle input = ctx->Input(0);
+ xla::ComputationDataHandle filter = ctx->Input(1);
+ xla::ComputationDataHandle output;
+
+ const string custom_function_name = kDepthwiseConv2dCustomFunc;
+ if (!custom_function_name.empty()) {
+ xla::Shape xla_out_shape;
+ OP_REQUIRES_OK(
+ ctx, TensorShapeToXLAShape(input_type(0), out_shape, &xla_out_shape));
+
+ // The custom function for depthwise should interpret its arguments
+ // as follows:
+ // func(T* output,
+ // const T* input, const T* filter,
+ // const int32* input_size, const int32* filter_size,
+ // const int32* output_size,
+ // int32 stride, int32 pad_rows, int32 pad_cols)
+ //
+ // where T is the type of Tensor that this kernel is registered for.
+ // Note that the custom call op uses the following calling
+ // convention:
+ // func(void* output, void** inputs)
+ //
+ // Therefore the custom function should first construct the above
+ // inputs by unpacking the second argument passed to it.
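+ //
+ // Note that no such custom function is provided by default:
+ // kDepthwiseConv2dCustomFunc is the empty string above, so the
+ // per-channel convolution fallback in the else branch is what runs.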
+ output = b.CustomCall(
+ custom_function_name,
+ {input, filter,
+ b.ConstantR1<int64>({batch, input_rows, input_cols, in_depth}),
+ b.ConstantR1<int64>(
+ {filter_rows, filter_cols, in_depth, depth_multiplier}),
+ b.ConstantR1<int64>({batch, out_rows, out_cols, out_depth}),
+ b.ConstantR0<int64>(stride), b.ConstantR0<int64>(pad_rows),
+ b.ConstantR0<int64>(pad_cols)},
+ xla_out_shape);
+ } else {
+ // These will be used to define the bounds of each slice.
+ // Within the loop, the input_channel index will be modified.
+ gtl::InlinedVector<int64, 4> filter_begin;
+ gtl::InlinedVector<int64, 4> filter_limits;
+ gtl::InlinedVector<int64, 4> input_begin;
+ gtl::InlinedVector<int64, 4> input_limits;
+ for (int i = 0; i < 4; ++i) {
+ filter_begin.push_back(0);
+ filter_limits.push_back(filter_shape.dim_size(i));
+ input_begin.push_back(0);
+ input_limits.push_back(input_shape.dim_size(i));
+ }
+
+ std::vector<int64> strides_for_xla{strides_[1], strides_[2]};
+
+ xla::Padding xla_padding =
+ (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ xla::ConvolutionDimensionNumbers dims;
+ dims.set_batch_dimension(0);
+ dims.set_feature_dimension(3);
+ dims.add_spatial_dimensions(1);
+ dims.add_spatial_dimensions(2);
+
+ // TF filter shape is [ H, W, inC, outC ]
+ dims.add_kernel_spatial_dimensions(0);
+ dims.add_kernel_spatial_dimensions(1);
+ dims.set_kernel_input_feature_dimension(2);
+ dims.set_kernel_output_feature_dimension(3);
+
+ // Create one convolution for each input channel
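+ // Each filter slice has shape [H, W, 1, depth_multiplier], so each
+ // per-channel convolution produces depth_multiplier output channels;
+ // concatenating the in_depth results along the feature dimension gives
+ // the full in_depth * depth_multiplier output depth.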
+ std::vector<xla::ComputationDataHandle> convs;
+ for (int i = 0; i < in_depth; ++i) {
+ filter_begin[2] = i;
+ filter_limits[2] = i + 1;
+ input_begin[3] = i;
+ input_limits[3] = i + 1;
+
+ xla::ComputationDataHandle filter_slice =
+ b.Slice(filter, filter_begin, filter_limits);
+ xla::ComputationDataHandle input_slice =
+ b.Slice(input, input_begin, input_limits);
+ convs.push_back(b.ConvWithGeneralDimensions(
+ input_slice, filter_slice, strides_for_xla, xla_padding, dims));
+ }
+ // Concatenate the per-channel convolutions along the depth dimension.
+ output = b.ConcatInDim(convs, 3);
+ }
+
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
+};
+
+REGISTER_XLA_OP("DepthwiseConv2dNative", DepthwiseConv2dNativeOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
new file mode 100644
index 0000000000..b89109ff6a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc
@@ -0,0 +1,255 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class DiagOp : public XlaOpKernel {
+ public:
+ explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ auto dims = input_shape.dim_sizes();
+ OP_REQUIRES(ctx, !dims.empty(),
+ errors::InvalidArgument("Expected 1 <= dims, got shape ",
+ input_shape.DebugString()));
+
+ xla::ComputationDataHandle diag = ctx->Input(0);
+
+ // Picture:
+ // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0]
+ // [0, 2, 0, 0]
+ // [0, 0, 3, 0]
+ // [0, 0, 0, 4]]
+
+ // Flattens the input to 1D.
+ int64 size = input_shape.num_elements();
+ diag = builder->Reshape(diag, {size});
+
+ // Adds inter-element padding of 'size'.
+ xla::PaddingConfig config;
+ auto* dim = config.add_dimensions();
+ dim->set_interior_padding(size);
+ diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config);
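+ // The padded vector has size + (size - 1) * size = size * size elements,
+ // and input element i lands at flat index i * (size + 1), i.e. on the
+ // diagonal once reshaped below.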
+
+ // Reshapes to the final shape.
+ std::vector<int64> new_dims(dims.size() * 2);
+ std::copy(dims.begin(), dims.end(), new_dims.begin());
+ std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size());
+ diag = builder->Reshape(diag, new_dims);
+
+ ctx->SetOutput(0, diag);
+ }
+};
+
+REGISTER_XLA_OP("Diag", DiagOp);
+
+class DiagPartOp : public XlaOpKernel {
+ public:
+ explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+ auto dims = input_shape.dim_sizes();
+
+ int num_dims = dims.size();
+ const int out_dims = num_dims / 2;
+
+ OP_REQUIRES(ctx, 2 <= num_dims,
+ errors::InvalidArgument("Expected 2 <= dims, got shape ",
+ input_shape.DebugString()));
+ OP_REQUIRES(ctx, num_dims % 2 == 0,
+ errors::InvalidArgument("The input tensor must have even rank; "
+ "got shape ",
+ input_shape.DebugString()));
+ int64 new_size = 1;
+ std::vector<int64> new_dims;
+ for (int i = 0; i < out_dims; i++) {
+ OP_REQUIRES(
+ ctx, dims[i] == dims[i + out_dims],
+ errors::InvalidArgument("Invalid shape ", input_shape.DebugString(),
+ ": dimensions ", i, " and ", i + out_dims,
+ " do not match."));
+ new_size *= dims[i];
+ new_dims.push_back(dims[i]);
+ }
+
+ xla::ComputationDataHandle diag = ctx->Input(0);
+
+ // TODO(b/30878775): use Slice with strides when supported, in place of
+ // the Pad -> Reshape -> Slice.
+
+ // Picture:
+ // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
+ // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
+ // [0, 0, 3, 0] [3, 0, 0, 0, 0],
+ // [0, 0, 0, 4]] [4, 0, 0, 0, 0]]
+ // and then slice out the first column.
+
+ // Flattens the input to 1D.
+ int64 size = input_shape.num_elements();
+ diag = builder->Reshape(diag, {size});
+
+ // Adds 'new_size' zeros of padding after the last element.
+ xla::PaddingConfig config;
+ auto* dim = config.add_dimensions();
+ dim->set_edge_padding_high(new_size);
+ auto zero = XlaHelpers::Zero(builder, input_type(0));
+ diag = builder->Pad(diag, zero, config);
+
+ // Reshapes so the diagonal is now in the first column.
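+ // (A diagonal element sits at flat index i * (new_size + 1), so with
+ // rows of length new_size + 1 it falls in column 0 of row i.)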
+ diag = builder->Reshape(diag, {new_size, new_size + 1});
+
+ // Slices out the first column and reshapes to the final shape.
+ diag = builder->Slice(diag, {0, 0}, {new_size, 1});
+ diag = builder->Reshape(diag, new_dims);
+
+ ctx->SetOutput(0, diag);
+ }
+};
+
+REGISTER_XLA_OP("DiagPart", DiagPartOp);
+
+class MatrixDiagOp : public XlaOpKernel {
+ public:
+ explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ auto dims = input_shape.dim_sizes();
+ OP_REQUIRES(ctx, !dims.empty(),
+ errors::InvalidArgument("Expected 1 <= dims, got shape ",
+ input_shape.DebugString()));
+
+ xla::ComputationDataHandle diag = ctx->Input(0);
+
+ int last_dim = dims.size() - 1;
+ int64 last_dim_size = input_shape.dim_size(last_dim);
+
+ // Adds inter-element padding of 'last_dim_size' to the last dimension.
+ xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size());
+ auto* dim = config.mutable_dimensions(last_dim);
+ dim->set_interior_padding(last_dim_size);
+ diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config);
+
+ // Reshapes to the final shape.
+ dims.push_back(last_dim_size);
+ diag = builder->Reshape(diag, dims);
+
+ ctx->SetOutput(0, diag);
+ }
+};
+
+REGISTER_XLA_OP("MatrixDiag", MatrixDiagOp);
+
+class MatrixDiagPartOp : public XlaOpKernel {
+ public:
+ explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ const TensorShape input_shape = ctx->InputShape(0);
+ auto dims = input_shape.dim_sizes();
+
+ OP_REQUIRES(ctx, 2 <= dims.size(),
+ errors::InvalidArgument("Expected 2 <= dims, got shape ",
+ input_shape.DebugString()));
+
+ xla::ComputationDataHandle diag = ctx->Input(0);
+
+ int last_dim = dims.size() - 1;
+ int64 last_dim_size = dims[last_dim];
+
+ // The smaller of the last two dimension sizes.
+ int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]);
+
+ // TODO(b/30878775): use Slice with strides when supported, in place of
+ // the Pad -> Reshape -> Slice.
+
+ // Picture: for each 2D matrix in the tensor's last two dimensions:
+ // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0],
+ // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0],
+ // [0, 0, 3, 0]] [3, 0, 0, 0, 0],
+ // and then slice out the first column.
+ //
+ // Another example, with tall and narrow input.
+ // [[1, 0] pad and reshape to [[1, 0, 0],
+ // [0, 2] =================> [2, 0, 0]]
+ // [0, 0]
+ // [0, 0]]
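+ //
+ // In the first example target_size = 3 * (4 + 1) = 15, so the 12 elements
+ // are padded up to 15 and reshaped to [3, 5]; in the second example
+ // target_size = 2 * (2 + 1) = 6, so the 8 elements are sliced down to 6
+ // and reshaped to [2, 3]. Either way the diagonal ends up in the first
+ // column.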
+
+ // Collapses the last two dimensions.
+ std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1);
+ flattened_dims.back() *= dims.back();
+ diag = builder->Reshape(diag, flattened_dims);
+
+ // Slices or pads the last dimension to 'target_size'.
+ int64 actual_size = flattened_dims.back();
+ int64 target_size = smaller_dim_size * (last_dim_size + 1);
+ if (actual_size < target_size) {
+ xla::PaddingConfig config =
+ xla::MakeNoPaddingConfig(flattened_dims.size());
+ auto* dim = config.mutable_dimensions(flattened_dims.size() - 1);
+ dim->set_edge_padding_high(target_size - actual_size);
+ auto zero = XlaHelpers::Zero(builder, input_type(0));
+ diag = builder->Pad(diag, zero, config);
+ } else if (actual_size > target_size) {
+ std::vector<int64> start(flattened_dims.size(), 0);
+ std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end());
+ limits[flattened_dims.size() - 1] = target_size;
+ diag = builder->Slice(diag, start, limits);
+ }
+
+ // Reshape so the target values are in the first position of the last
+ // dimension.
+ dims[last_dim - 1] = smaller_dim_size;
+ dims[last_dim] = last_dim_size + 1;
+ diag = builder->Reshape(diag, dims);
+
+ // Slices out the first column and reshapes to the final shape.
+ std::vector<int64> start(dims.size(), 0);
+ std::vector<int64> limits(dims.begin(), dims.end());
+ limits[last_dim] = 1;
+ diag = builder->Slice(diag, start, limits);
+
+ // Collapses away the last dimension.
+ dims.pop_back();
+ diag = builder->Reshape(diag, dims);
+
+ ctx->SetOutput(0, diag);
+ }
+};
+
+REGISTER_XLA_OP("MatrixDiagPart", MatrixDiagPartOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
new file mode 100644
index 0000000000..2936e79261
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
@@ -0,0 +1,200 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific dynamic stitch Op.
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+namespace {
+
+class DynamicStitchOp : public XlaOpKernel {
+ public:
+ explicit DynamicStitchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() > 0,
+ errors::InvalidArgument("DynamicStitchOp: Must have some inputs"));
+ OP_REQUIRES(ctx, ctx->num_inputs() % 2 == 0,
+ errors::InvalidArgument(
+ "DynamicStitchOp: Must have even number of arguments"));
+ // Compute expected input signature
+ const int n = ctx->num_inputs() / 2;
+ const DataType dt = ctx->input_type(n);
+ DataTypeVector expected;
+ for (int i = 0; i < n; i++) {
+ expected.push_back(DT_INT32);
+ }
+ for (int i = 0; i < n; i++) {
+ expected.push_back(dt);
+ }
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected, {dt}));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // Validate that data[i].shape = indices[i].shape + a common extra shape.
+ std::vector<xla::Literal> indices_input;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputList("indices", &indices_input));
+
+ std::vector<xla::ComputationDataHandle> data;
+ std::vector<TensorShape> data_shapes;
+ OP_REQUIRES_OK(ctx, ctx->InputList("data", &data, &data_shapes));
+
+ std::vector<xla::Literal> indices(indices_input.size());
+
+ const TensorShape& data0_shape = data_shapes[0];
+ const TensorShape indices0_shape =
+ XLAShapeToTensorShape(indices_input[0].shape());
+ for (int input_num = 0; input_num < indices_input.size(); input_num++) {
+ const TensorShape indices_shape =
+ XLAShapeToTensorShape(indices_input[input_num].shape());
+ const TensorShape& data_shape = data_shapes[input_num];
+ OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape),
+ errors::InvalidArgument(
+ "data[", input_num, "].shape = ",
+ data_shape.DebugString(), " does not start with indices[",
+ input_num, "].shape = ", indices_shape.DebugString()));
+ OP_REQUIRES(ctx,
+ input_num == 0 || SameExtraShape(data0_shape, indices0_shape,
+ data_shape, indices_shape),
+ errors::InvalidArgument(
+ "Need data[0].shape[", indices0_shape.dims(),
+ ":] = data[", input_num, "].shape[", indices_shape.dims(),
+ ":], got data[0].shape = ", data0_shape.DebugString(),
+ ", data[", input_num, "].shape = ",
+ data_shape.DebugString(), ", indices[0].shape = ",
+ indices0_shape.DebugString(), ", indices[", input_num,
+ "].shape = ", indices_shape.DebugString()));
+
+ OP_REQUIRES_OK(ctx,
+ XlaHelpers::ReshapeLiteral(indices_input[input_num],
+ {indices_shape.num_elements()},
+ &indices[input_num]));
+ }
+
+ // Find which slice will be used for each index. If the same index
+ // appears in multiple inputs, the last one is used. The logic
+ // here is different from that in third_party/tensorflow because
+ // it is important for XLA that there be a well-formed Concat
+ // operation at the end. The existing CPU/GPU code copies multiple
+ // source slices to the same destination slice if there are
+ // repeated indices, whereas the XLA code works out which
+ // source slice will 'win' and only uses that in the Concat.
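+ // For example, with indices = {[0, 2], [2, 1]}, index 2 appears in both
+ // inputs and the later occurrence wins, so the output concatenates
+ // data[0] slice 0, data[1] slice 1, and data[1] slice 0 in index order.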
+ int max_index = -1;
+ for (int input_num = 0; input_num < indices.size(); input_num++) {
+ for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
+ max_index = std::max(
+ max_index, xla::LiteralUtil::Get<int>(indices[input_num], {i}));
+ }
+ }
+ int number_of_indices = max_index + 1;
+ OP_REQUIRES(ctx, number_of_indices > 0,
+ errors::InvalidArgument("no indices supplied"));
+ // Construct the reverse mapping, for each index, of which slice of which
+ // input it comes from.
+ std::vector<int32> src_input_vector(number_of_indices);
+ std::vector<int32> src_slice_vector(number_of_indices);
+ std::vector<bool> src_index_used(number_of_indices);
+ int index_used_count = 0;
+ for (int input_num = 0; input_num < indices.size(); input_num++) {
+ for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
+ int index = xla::LiteralUtil::Get<int>(indices[input_num], {i});
+ src_input_vector[index] = input_num;
+ src_slice_vector[index] = i;
+ if (!src_index_used[index]) {
+ src_index_used[index] = true;
+ ++index_used_count;
+ }
+ }
+ }
+ OP_REQUIRES(ctx, index_used_count == number_of_indices,
+ errors::InvalidArgument("not all indices are used"));
+
+ // Look up all the children expressions that represent the data
+ // inputs.
+ std::vector<xla::ComputationDataHandle> input(indices.size());
+ for (int input_num = 0; input_num < indices.size(); input_num++) {
+ TensorShape new_shape;
+ // The first reshaped dimension is the number of indices for this input.
+ new_shape.AddDim(indices[input_num].shape().dimensions(0));
+ // Then the rest are the common extra shape.
+ for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
+ new_shape.AddDim(data0_shape.dim_size(d));
+ }
+ // Get the data, shaped appropriately.
+ auto handle = data[input_num];
+ if (new_shape == data_shapes[input_num]) {
+ input[input_num] = handle;
+ } else {
+ input[input_num] =
+ ctx->builder()->Reshape(handle, new_shape.dim_sizes());
+ }
+ }
+
+ // Set up the vectors for slicing: the first dimension will vary
+ // slice by slice, and the rest take the full common extra shape.
+ std::vector<int64> slice_start(1 + data0_shape.dims() -
+ indices0_shape.dims());
+ std::vector<int64> slice_limit(1 + data0_shape.dims() -
+ indices0_shape.dims());
+ for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) {
+ slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d);
+ }
+ std::vector<xla::ComputationDataHandle> to_concat(number_of_indices);
+ for (int index_num = 0; index_num < number_of_indices; index_num++) {
+ const auto& expression = input[src_input_vector[index_num]];
+ // Take the appropriate slice of data.
+ slice_start[0] = src_slice_vector[index_num];
+ slice_limit[0] = src_slice_vector[index_num] + 1;
+ // And place it in the concat list in the place indicated by
+ // the index.
+ to_concat[index_num] =
+ ctx->builder()->Slice(expression, slice_start, slice_limit);
+ }
+
+ ctx->SetOutput(0, ctx->builder()->ConcatInDim(to_concat, 0));
+ }
+
+ private:
+ // Check if data0_shape[indices0.dims():] == data1_shape[indices1.dims():]
+ static bool SameExtraShape(const TensorShape& data0_shape,
+ const TensorShape& indices0,
+ const TensorShape& data1_shape,
+ const TensorShape& indices1) {
+ const int extra0 = data0_shape.dims() - indices0.dims();
+ const int extra1 = data1_shape.dims() - indices1.dims();
+ if (extra0 != extra1) return false;
+ for (int i = 0; i < extra0; i++) {
+ if (data0_shape.dim_size(indices0.dims() + i) !=
+ data1_shape.dim_size(indices1.dims() + i)) {
+ return false;
+ }
+ }
+ return true;
+ }
+};
+
+REGISTER_XLA_OP("DynamicStitch", DynamicStitchOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
new file mode 100644
index 0000000000..918c80aad8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Fill Op.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+namespace {
+
+class FillOp : public XlaOpKernel {
+ public:
+ explicit FillOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ // The output of this Op is a tensor whose shape is given by the values of
+ // the 'dims' input, with every element set to the scalar 'value' input.
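+ // For example, dims = [2, 3] with value = 7 produces [[7, 7, 7], [7, 7, 7]].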
+ const TensorShape dims_shape = ctx->InputShape(0);
+ const TensorShape value_shape = ctx->InputShape(1);
+ OP_REQUIRES(
+ ctx, IsLegacyVector(dims_shape),
+ errors::InvalidArgument("dims must be a vector of int32, got shape ",
+ dims_shape.DebugString()));
+ OP_REQUIRES(ctx, IsLegacyScalar(value_shape),
+ errors::InvalidArgument("value must be a scalar, got shape ",
+ value_shape.DebugString()));
+ // Evaluate the 'dims' constant input, reshaping to a vector if it
+ // was a 'legacy' vector (secretly a scalar).
+ xla::Literal dims_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(
+ 0, {dims_shape.num_elements()}, &dims_literal));
+
+ // Convert the dims literal into a vector that we can pass to
+ // ComputationBuilder.
+ std::vector<int64> broadcast;
+ for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) {
+ broadcast.push_back(xla::LiteralUtil::Get<int>(dims_literal, {i}));
+ }
+ // Look up the value input, reshaping to a scalar if it was a
+ // 'legacy' scalar (secretly a vector).
+ xla::ComputationDataHandle data = ctx->Input(1);
+ if (value_shape.dims() > 0) {
+ CHECK_EQ(value_shape.dims(), 1);
+ data = ctx->builder()->Reshape(data, {});
+ }
+ // Emit the actual computation, which broadcasts the scalar to the
+ // desired shape.
+ auto result = ctx->builder()->Broadcast(data, broadcast);
+
+ ctx->SetOutput(0, result);
+ }
+};
+
+REGISTER_XLA_OP("Fill", FillOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/function_ops.cc b/tensorflow/compiler/tf2xla/kernels/function_ops.cc
new file mode 100644
index 0000000000..53f2196dc5
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/function_ops.cc
@@ -0,0 +1,110 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+const char* const kGradientOp = "SymbolicGradient";
+
+// Implementations of _ListToArray and _ArrayToList for functions.
+class PassOn : public XlaOpKernel {
+ public:
+ explicit PassOn(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
+ errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
+ " vs. ", ctx->num_outputs()));
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ OP_REQUIRES(
+ ctx, input_type(i) == output_type(i),
+ errors::Internal("Input and output types for position ", i,
+ " do not match: ", DataTypeString(input_type(i)),
+ " vs. ", DataTypeString(output_type(i))));
+ }
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ ctx->SetOutput(i, ctx->Input(i));
+ }
+ }
+};
+
+REGISTER_XLA_OP("_ListToArray", PassOn);
+REGISTER_XLA_OP("_ArrayToList", PassOn);
+
+// TODO(phawkins): this is an almost exact copy of the SymbolicGradientOp
+// implementation from regular TensorFlow. Once XLA has been open-sourced,
+// merge the two implementations. (Note: this implementation propagates the
+// step_resource_manager.)
+class SymbolicGradientOp : public AsyncOpKernel {
+ public:
+ explicit SymbolicGradientOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx), handle_(-1) {}
+
+ ~SymbolicGradientOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
+
+ OP_REQUIRES_OK_ASYNC(
+ ctx, lib->Instantiate(kGradientOp, def().attr(), &handle_), done);
+
+ FunctionLibraryRuntime::Options opts;
+ opts.step_id = ctx->step_id();
+ opts.runner = ctx->runner();
+ opts.step_container = ctx->step_container();
+ std::vector<Tensor> args;
+ args.reserve(ctx->num_inputs());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ args.push_back(ctx->input(i));
+ }
+ std::vector<Tensor>* rets = new std::vector<Tensor>;
+ lib->Run(
+ opts, handle_, args, rets, [ctx, done, rets](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else if (rets->size() != ctx->num_outputs()) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "SymGrad expects to return ", ctx->num_outputs(),
+ " tensor(s), but get ", rets->size(), " tensor(s) instead."));
+ } else {
+ for (size_t i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ done();
+ });
+ }
+
+ private:
+ FunctionLibraryRuntime::Handle handle_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
+};
+
+REGISTER_XLA_OP(kGradientOp, SymbolicGradientOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
new file mode 100644
index 0000000000..b98d386479
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc
@@ -0,0 +1,104 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+class GatherOp : public XlaOpKernel {
+ public:
+ explicit GatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape params_shape = ctx->InputShape(0);
+ const TensorShape indices_shape = ctx->InputShape(1);
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVectorOrHigher(params_shape),
+ errors::InvalidArgument("params must be at least 1 dimensional"));
+
+ DataType index_type = input_type(1);
+ OP_REQUIRES(ctx, index_type == DT_INT32 || index_type == DT_INT64,
+ errors::InvalidArgument("index must be int32 or int64"));
+
+ // Check that we have enough index space.
+ const int64 limit = index_type == DT_INT32
+ ? std::numeric_limits<int32>::max()
+ : std::numeric_limits<int64>::max();
+ OP_REQUIRES(
+ ctx, params_shape.dim_size(0) <= limit,
+ errors::InvalidArgument("params.shape[0] too large for ",
+ DataTypeString(index_type), " indexing: ",
+ params_shape.dim_size(0), " > ", limit));
+
+ // The result shape is indices.shape + params.shape[1:].
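+ // For example, params of shape [5, 3] gathered with indices of shape
+ // [2, 2] produce a result of shape [2, 2, 3].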
+ TensorShape result_shape = indices_shape;
+ for (int i = 1; i < params_shape.dims(); i++) {
+ result_shape.AddDim(params_shape.dim_size(i));
+ }
+
+ XlaContext& tc = XlaContext::Get(ctx);
+ OP_REQUIRES(
+ ctx, tc.allow_cpu_custom_calls(),
+ errors::InvalidArgument("Gather op requires CustomCall on CPU"));
+
+ xla::ComputationBuilder& b = *ctx->builder();
+
+ // Call gather_float_int32_xla_impl or gather_float_int64_xla_impl (from
+ // gather_op_kernel_float_int32.cc / gather_op_kernel_float_int64.cc).
+ // XLA passes <out> to the function, so it is not included here.
+ std::vector<xla::ComputationDataHandle> args;
+ args.push_back(tc.GetOrCreateRuntimeContextParameter());
+ args.push_back(b.ConstantLiteral(
+ *xla::LiteralUtil::CreateR0<int64>(indices_shape.num_elements())));
+ args.push_back(b.ConstantLiteral(
+ *xla::LiteralUtil::CreateR0<int64>(params_shape.dim_size(0))));
+ args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int64>(
+ params_shape.num_elements() / params_shape.dim_size(0))));
+ args.push_back(ctx->Input(0));
+ args.push_back(ctx->Input(1));
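+ // The kernel therefore sees its 'data' array as: data[0] = runtime
+ // context, data[1] = number of indices, data[2] = params.shape[0],
+ // data[3] = product of the remaining params dimensions, data[4] = the
+ // params values, and data[5] = the indices values.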
+
+ xla::Shape xla_out_shape;
+ OP_REQUIRES_OK(
+ ctx, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape));
+
+ // Call the custom code with args:
+ xla::ComputationDataHandle output;
+ if (index_type == DT_INT32) {
+ output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape);
+ } else {
+ output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape);
+ }
+
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(GatherOp);
+};
+
+REGISTER_XLA_OP("Gather", GatherOp);
+
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Gather")
+ .TypeConstraint("Tparams", DT_FLOAT)
+ .TypeConstraint("Tindices", {DT_INT32, DT_INT64}));
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc
new file mode 100644
index 0000000000..eff23bd77d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+
+namespace tensorflow {
+
+EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) {
+ // data is managed by the JIT code so msan can't tell it's initialized.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*));
+
+ int64 indices_size = *static_cast<int64*>(data[1]);
+ int64 params_x = *static_cast<int64*>(data[2]);
+ int64 params_y = *static_cast<int64*>(data[3]);
+
+ float* in = static_cast<float*>(data[4]);
+
+ int32* indices = static_cast<int32*>(data[5]);
+ Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes;
+ in_eig_sizes[0] = params_x;
+ in_eig_sizes[1] = params_y;
+ tensorflow::TTypes<float, 2>::ConstMatrix in_eig(in, in_eig_sizes);
+
+ Eigen::DSizes<Eigen::DenseIndex, 1> indices_eig_sizes;
+ indices_eig_sizes[0] = indices_size;
+ tensorflow::TTypes<int32>::ConstFlat indices_eig(indices, indices_eig_sizes);
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> out_eig_sizes;
+ out_eig_sizes[0] = indices_size;
+ out_eig_sizes[1] = params_y;
+ tensorflow::TTypes<float>::Matrix out_eig(out, out_eig_sizes);
+
+ tensorflow::functor::GatherFunctorCPU<float, int32> f;
+ const int64 bad_i = f(in_eig, indices_eig, out_eig);
+ if (bad_i != -1) {
+ tensorflow::XlaLocalRuntimeContext* runtime_context =
+ static_cast<tensorflow::XlaLocalRuntimeContext*>(data[0]);
+ runtime_context->error = true;
+ runtime_context->error_msg = "Invalid index for gather";
+ for (int i = 0; i < out_eig.size(); ++i) out[i] = 0;
+ }
+}
+
+} // namespace tensorflow
+
+// Implements gather on CPU. This is called by an XLA custom call, set up by
+// gather_op.cc.
+extern "C" void __attribute__((visibility("default")))
+gather_float_int32_xla_impl(float* out, void** data) {
+ tensorflow::gather_float_int32_xla_impl(out, data);
+}
diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc
new file mode 100644
index 0000000000..ae31f6f200
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+
+namespace tensorflow {
+
+EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) {
+ // data is managed by the JIT code so msan can't tell it's initialized.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 6 * sizeof(void*));
+
+ int64 indices_size = *static_cast<int64*>(data[1]);
+ int64 params_x = *static_cast<int64*>(data[2]);
+ int64 params_y = *static_cast<int64*>(data[3]);
+
+ float* in = static_cast<float*>(data[4]);
+
+ int64* indices = static_cast<int64*>(data[5]);
+ Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes;
+ in_eig_sizes[0] = params_x;
+ in_eig_sizes[1] = params_y;
+ tensorflow::TTypes<float, 2>::ConstMatrix in_eig(in, in_eig_sizes);
+
+ Eigen::DSizes<Eigen::DenseIndex, 1> indices_eig_sizes;
+ indices_eig_sizes[0] = indices_size;
+ tensorflow::TTypes<int64>::ConstFlat indices_eig(indices, indices_eig_sizes);
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> out_eig_sizes;
+ out_eig_sizes[0] = indices_size;
+ out_eig_sizes[1] = params_y;
+ tensorflow::TTypes<float>::Matrix out_eig(out, out_eig_sizes);
+
+ tensorflow::functor::GatherFunctorCPU<float, int64> f;
+ const int64 bad_i = f(in_eig, indices_eig, out_eig);
+ if (bad_i != -1) {
+ tensorflow::XlaLocalRuntimeContext* runtime_context =
+ static_cast<tensorflow::XlaLocalRuntimeContext*>(data[0]);
+ runtime_context->error = true;
+ runtime_context->error_msg = "Invalid index for gather";
+ for (int i = 0; i < out_eig.size(); ++i) out[i] = 0;
+ }
+}
+
+} // namespace tensorflow
+
+// Implements gather on CPU. This is called by an XLA custom call, set up by
+// gather_op.cc.
+extern "C" void __attribute__((visibility("default")))
+gather_float_int64_xla_impl(float* out, void** data) {
+ tensorflow::gather_float_int64_xla_impl(out, data);
+}
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
new file mode 100644
index 0000000000..01417a3cdf
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class IdentityOp : public XlaOpKernel {
+ public:
+ explicit IdentityOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->SetOutput(0, ctx->Input(0));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(IdentityOp);
+};
+
+REGISTER_XLA_OP("Identity", IdentityOp);
+REGISTER_XLA_OP("PreventGradient", IdentityOp);
+REGISTER_XLA_OP("StopGradient", IdentityOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
new file mode 100644
index 0000000000..293705e39f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of indexing ops.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+namespace {
+
+// The logic below uses a custom-call to implement argmax.
+//
+// TODO(toddw): We can implement argmax using existing XLA ops. The idea is
+// to use SelectAndScatter to create a tensor initialized to 0, where the max
+// value along dim is set to 1. Then take the dot-product of that against a
+// vector of indices [0,dim_size), which yields the result. As a detail, we
+// might need to reshape before and afterwards, since the XLA Dot operator
+// only performs the sum of products over dimension 0.
+//
+// rs = Reshape(input, ...) // reshape so dim is inner-most
+// one_max = SelectAndScatter(rs, greater_than,
+// {1,1,...,dim_size}, {1,1,...,dim_size},
+// VALID, [1], 0, add)
+// indices = [0,1,2,...,dim_size-1]
+// max_index = Dot(one_max, indices)
+// result = Reshape(max_index, ...) // reshape back to original
+//
+// Also see b/29507024 for first-class XLA support for indexing ops.
+
+class ArgMaxOp : public XlaOpKernel {
+ public:
+ explicit ArgMaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape dimension_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape),
+ errors::InvalidArgument(
+ "dim must be a scalar, but received tensor of shape: ",
+ dimension_shape.DebugString()));
+
+ // We require that the dimension argument is a constant, since it lets us
+ // dispatch to a specialized custom-call function without any run-time
+ // overhead, when compiling ahead-of-time.
+ //
+ // TODO(toddw): We could remove this requirement if necessary; we'd also
+ // need to update const_analysis.cc. However it seems likely that a native
+ // XLA op would have the same requirement.
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+ const int32 dim = xla::LiteralUtil::Get<int32>(literal, {});
+ OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
+ OP_REQUIRES(
+ ctx, dim < input_shape.dims(),
+ errors::InvalidArgument("dim must be < input rank (",
+ input_shape.dims(), "), but got: ", dim));
+ const int64 dim_size = input_shape.dim_size(dim);
+ OP_REQUIRES(
+ ctx, dim_size > 0,
+ errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ",
+ input_shape.DebugString()));
+
+ // The output shape is the input shape contracted along dim.
+ TensorShape output_shape;
+ for (int d = 0; d < input_shape.dims() - 1; ++d) {
+ output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
+ }
+
+ // For now we use a custom-call, only for the 1d and 2d cases.
+ OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(),
+ errors::InvalidArgument(
+ "ArgMax implementation requires a CustomCall on CPU"));
+ xla::ComputationBuilder& b = *ctx->builder();
+
+ // XLA passes <out> to the function, so it is not included here.
+ std::vector<xla::ComputationDataHandle> args;
+ args.push_back(ctx->Input(0));
+ args.push_back(b.ConstantLiteral(
+ *xla::LiteralUtil::CreateR1<int64>(input_shape.dim_sizes())));
+ if (input_shape.dims() > 1) {
+ // Don't bother passing the output shape and dim for the 1d case, since
+ // the shape is always a scalar and the dim is always 0.
+ args.push_back(b.ConstantLiteral(
+ *xla::LiteralUtil::CreateR1<int64>(output_shape.dim_sizes())));
+ args.push_back(
+ b.ConstantLiteral(*xla::LiteralUtil::CreateR0<int32>(dim)));
+ }
+
+ xla::Shape xla_shape =
+ xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes());
+
+ // Tell XLA to call the custom code, defined in
+ // index_ops_kernel_argmax_float_1d.cc.
+ xla::ComputationDataHandle output;
+ switch (input_shape.dims()) {
+ case 1:
+ output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape);
+ break;
+ case 2:
+ output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape);
+ break;
+ default:
+ OP_REQUIRES(ctx, false,
+ errors::Unimplemented(
+ "Argmax is only implemented for 1d and 2d tensors"
+ ", but got shape: ",
+ input_shape.DebugString()));
+ }
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxOp);
+};
+
+REGISTER_XLA_OP("ArgMax", ArgMaxOp);
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("ArgMax").TypeConstraint("T", DT_FLOAT));
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc
new file mode 100644
index 0000000000..0033a949a3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) {
+ // data is managed by the JIT code so msan can't tell it's initialized.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 2 * sizeof(void*));
+
+ float* input = static_cast<float*>(data[0]);
+ int64 input_size = *static_cast<int64*>(data[1]);
+
+ Eigen::DSizes<Eigen::DenseIndex, 1> in_eig_sizes(input_size);
+ TTypes<float, 1>::ConstTensor in_eig(input, in_eig_sizes);
+
+ Eigen::DSizes<Eigen::DenseIndex, 0> out_eig_sizes;
+ int64* out_t = static_cast<int64*>(out);
+ TTypes<int64, 0>::Tensor out_eig(out_t, out_eig_sizes);
+
+ out_eig = in_eig.argmax(0).cast<int64>();
+}
+
+} // namespace tensorflow
+
+// Implements argmax on CPU. This is called by an XLA custom call, set up by
+// index_ops.cc.
+extern "C" void __attribute__((visibility("default")))
+argmax_float_1d_xla_impl(void* out, void** data) {
+ tensorflow::argmax_float_1d_xla_impl(out, data);
+}
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc
new file mode 100644
index 0000000000..be8ad2317c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) {
+ // data is managed by the JIT code so msan can't tell it's initialized.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 4 * sizeof(void*));
+
+ float* in = static_cast<float*>(data[0]);
+ int64* in_sizes = static_cast<int64*>(data[1]);
+ int64* out_sizes = static_cast<int64*>(data[2]);
+ int32 dim = *static_cast<int32*>(data[3]);
+
+ Eigen::DSizes<Eigen::DenseIndex, 2> in_eig_sizes(in_sizes[0], in_sizes[1]);
+ TTypes<float, 2>::ConstTensor in_eig(in, in_eig_sizes);
+
+ int64* out_t = static_cast<int64*>(out);
+ Eigen::DSizes<Eigen::DenseIndex, 1> out_eig_sizes(out_sizes[0]);
+ TTypes<int64, 1>::Tensor out_eig(out_t, out_eig_sizes);
+
+ out_eig = in_eig.argmax(dim).cast<int64>();
+}
+
+} // namespace tensorflow
+
+// Implements argmax on CPU. This is called by an XLA custom call, set up by
+// index_ops.cc.
+extern "C" void __attribute__((visibility("default")))
+argmax_float_2d_xla_impl(void* out, void** data) {
+ tensorflow::argmax_float_2d_xla_impl(out, data);
+}
diff --git a/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
new file mode 100644
index 0000000000..248984bcfe
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/l2loss_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+namespace {
+
+class L2LossOp : public XlaOpKernel {
+ public:
+ explicit L2LossOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ DataType dtype = ctx->input_type(0);
+ xla::ComputationBuilder* b = ctx->builder();
+
+ auto zero = XlaHelpers::Zero(b, dtype);
+ auto two = XlaHelpers::IntegerLiteral(b, dtype, 2);
+ const xla::Computation& add = *ctx->GetOrCreateAdd(dtype);
+
+ std::vector<int64> dims(input_shape.dims());
+ std::iota(dims.begin(), dims.end(), 0);
+
+ // output = sum(t ** 2) / 2
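+ // For example, t = [1, 2, 3] gives (1 + 4 + 9) / 2 = 7.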
+ auto x = ctx->Input(0);
+ ctx->SetOutput(0, b->Div(b->Reduce(b->Mul(x, x), zero, add, dims), two));
+ }
+};
+
+REGISTER_XLA_OP("L2Loss", L2LossOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
new file mode 100644
index 0000000000..93966d3d5a
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/lrn_ops.cc
@@ -0,0 +1,173 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+// Local response normalization
+class LRNOp : public XlaOpKernel {
+ public:
+ explicit LRNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));
+
+ // TODO(phawkins): handle non-float types for attributes.
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape in_shape = ctx->InputShape(0);
+ OP_REQUIRES(ctx, in_shape.dims() == 4,
+ errors::InvalidArgument("in must be 4-dimensional"));
+
+ xla::ComputationBuilder* builder = ctx->builder();
+ xla::ComputationDataHandle input = ctx->Input(0);
+
+ // sqr_sum[a, b, c, d] =
+ // sum(input[a, b, c, d - depth_radius : d + depth_radius + 1] ** 2)
+ // output = input / (bias + alpha * sqr_sum) ** beta
+
+ // We use a window of depth_radius_ * 2 + 1, to account for the current
+ // element and a depth_radius_ on either side.
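+ // For example, depth_radius_ = 2 gives a window of 5, so sqr_sum at depth
+ // d sums the squares over depths d - 2 through d + 2, with kSame padding
+ // handling the edges.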
+ auto squared = builder->Mul(input, input);
+ auto sqr_sum = builder->ReduceWindow(
+ squared, XlaHelpers::Zero(builder, input_type(0)),
+ *ctx->GetOrCreateAdd(input_type(0)),
+ /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
+ /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
+
+ auto scale = builder->Pow(
+ builder->Add(builder->ConstantR0<float>(bias_),
+ builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum)),
+ builder->ConstantR0<float>(-beta_));
+
+ ctx->SetOutput(0, builder->Mul(input, scale));
+ }
+
+ private:
+ int64 depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+REGISTER_XLA_OP("LRN", LRNOp);
+
+class LRNGradOp : public XlaOpKernel {
+ public:
+ explicit LRNGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("depth_radius", &depth_radius_));
+
+ // TODO(phawkins): handle non-float types for attributes.
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("bias", &bias_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha", &alpha_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("beta", &beta_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape in_grads_shape = ctx->InputShape(0);
+ const TensorShape in_image_shape = ctx->InputShape(1);
+ const TensorShape out_image_shape = ctx->InputShape(2);
+
+ OP_REQUIRES(ctx, in_grads_shape.dims() == 4 && in_image_shape.dims() == 4,
+ errors::InvalidArgument("inputs must be 4-dimensional"));
+ const int64 batch = in_grads_shape.dim_size(0);
+ const int64 rows = in_grads_shape.dim_size(1);
+ const int64 cols = in_grads_shape.dim_size(2);
+ const int64 depth = in_grads_shape.dim_size(3);
+ OP_REQUIRES(
+ ctx, in_image_shape.dim_size(0) == batch &&
+ in_image_shape.dim_size(1) == rows &&
+ in_image_shape.dim_size(2) == cols &&
+ in_image_shape.dim_size(3) == depth &&
+ out_image_shape.dim_size(0) == batch &&
+ out_image_shape.dim_size(1) == rows &&
+ out_image_shape.dim_size(2) == cols &&
+ out_image_shape.dim_size(3) == depth,
+ errors::InvalidArgument(
+ "input_grads, input_image, and out_image should have the same "
+ "shape"));
+
+ xla::ComputationBuilder* builder = ctx->builder();
+ xla::ComputationDataHandle in_grads = ctx->Input(0);
+ xla::ComputationDataHandle in_image = ctx->Input(1);
+ xla::ComputationDataHandle out_image = ctx->Input(2);
+
+ // This code is ported from tensorflow/core/kernels/lrn_op.cc. In Python
+ // pseudo-code, the Eigen code does this for each spatial position:
+ // grads = [0.0] * depth
+ // for j in range(depth):
+ // depth_begin = max(0, j - depth_radius)
+ // depth_end = min(depth, j + depth_radius + 1)
+ //
+ // norm = 0
+ // for k in range(depth_begin, depth_end):
+ // norm += in_image[k] * in_image[k]
+ // norm = alpha * norm + bias
+ //
+ // for k in range(depth_begin, depth_end):
+ // dyi = -2.0 * alpha * beta * in_image[k] * out_image[j] / norm
+ // if k == j:
+ // dyi += norm ** (-beta)
+ // dyi *= out_grads[j]
+ // grads[k] += dyi
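+ //
+ // Below, 'dy' computes the -2.0 * alpha * beta * out_image / norm term
+ // multiplied by the incoming gradients for every position, 'dy_reduced'
+ // sums that term over each depth window, and the final expression adds
+ // the k == j contribution in_grads * norm ** (-beta).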
+
+ auto squared = builder->Mul(in_image, in_image);
+ auto sqr_sum = builder->ReduceWindow(
+ squared, XlaHelpers::Zero(builder, input_type(0)),
+ *ctx->GetOrCreateAdd(input_type(0)),
+ /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
+ /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
+
+ auto norm =
+ builder->Add(builder->ConstantR0<float>(bias_),
+ builder->Mul(builder->ConstantR0<float>(alpha_), sqr_sum));
+
+ auto dy = builder->Mul(
+ builder->Mul(builder->ConstantR0<float>(-2.0f * alpha_ * beta_),
+ builder->Div(out_image, norm)),
+ in_grads);
+
+ auto dy_reduced = builder->ReduceWindow(
+ dy, XlaHelpers::Zero(builder, input_type(0)),
+ *ctx->GetOrCreateAdd(input_type(0)),
+ /* window_dimensions = */ {1, 1, 1, depth_radius_ * 2 + 1},
+ /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kSame);
+
+ xla::ComputationDataHandle gradients = builder->Add(
+ builder->Mul(in_image, dy_reduced),
+ builder->Mul(in_grads,
+ builder->Pow(norm, builder->ConstantR0<float>(-beta_))));
+
+ ctx->SetOutput(0, gradients);
+ }
+
+ private:
+ int64 depth_radius_;
+ float bias_;
+ float alpha_;
+ float beta_;
+};
+
+REGISTER_XLA_OP("LRNGrad", LRNGradOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
new file mode 100644
index 0000000000..5af6a79f3e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -0,0 +1,88 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific MatMul Op.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+class MatMulOp : public XlaOpKernel {
+ public:
+ explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+ if (is_sparse) {
+ // SparseMatMul is actually dense matmul with a hint that one or
+ // both of the inputs may contain a lot of zeroes. On CPU these
+ // inputs are dynamically converted to sparse representation
+ // before multiplication. For now in XLA we ignore the hints
+ // and always do dense multiplication.
+ bool dummy_is_sparse;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("a_is_sparse", &dummy_is_sparse));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("b_is_sparse", &dummy_is_sparse));
+ }
+ }
+
+ ~MatMulOp() override = default;
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape a_shape = ctx->InputShape(0);
+ const TensorShape b_shape = ctx->InputShape(1);
+
+ // Check that the dimensions of the two matrices are valid.
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_shape),
+ errors::InvalidArgument("In[0] is not a matrix"));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b_shape),
+ errors::InvalidArgument("In[1] is not a matrix"));
+ int first_index = transpose_a_ ? 0 : 1;
+ int second_index = transpose_b_ ? 1 : 0;
+
+ OP_REQUIRES(ctx,
+ a_shape.dim_size(first_index) == b_shape.dim_size(second_index),
+ errors::InvalidArgument("Matrix size-compatible: In[0]: ",
+ a_shape.DebugString(), ", In[1]: ",
+ b_shape.DebugString()));
+
+ xla::ComputationDataHandle a = ctx->Input(0);
+ xla::ComputationDataHandle b = ctx->Input(1);
+ auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a;
+ auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b;
+ ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs));
+ }
+
+ private:
+ bool transpose_a_;
+ bool transpose_b_;
+};
+
+REGISTER_XLA_OP("MatMul", MatMulOp);
+
+class SparseMatMulOp : public MatMulOp {
+ public:
+ explicit SparseMatMulOp(OpKernelConstruction* ctx) : MatMulOp(ctx, true) {}
+
+ ~SparseMatMulOp() override = default;
+};
+
+REGISTER_XLA_OP("SparseMatMul", SparseMatMulOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/no_op.cc b/tensorflow/compiler/tf2xla/kernels/no_op.cc
new file mode 100644
index 0000000000..806bfc604f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/no_op.cc
@@ -0,0 +1,24 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/no_op.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+REGISTER_XLA_OP("NoOp", NoOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pack_op.cc b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
new file mode 100644
index 0000000000..7456d92de0
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/pack_op.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA Pack operator.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+class PackOp : public XlaOpKernel {
+ public:
+ explicit PackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ std::vector<xla::ComputationDataHandle> values;
+ std::vector<TensorShape> shapes;
+ OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes));
+ const int num = values.size();
+
+ OP_REQUIRES(ctx, num >= 1,
+ errors::InvalidArgument("Pack requires >= 1 arguments"));
+
+ // Verify that all input shapes match
+ for (int i = 1; i < num; i++) {
+ OP_REQUIRES(ctx, shapes[0].IsSameSize(shapes[i]),
+ errors::InvalidArgument(
+ "Shapes of all inputs must match: values[0].shape = ",
+ shapes[0].DebugString(), " != values[", i, "].shape = ",
+ shapes[i].DebugString()));
+ }
+
+ int expanded_num_dims = shapes[0].dims() + 1;
+ int axis = axis_;
+ if (axis < 0) axis += expanded_num_dims;
+
+ OP_REQUIRES(ctx, 0 <= axis && axis < expanded_num_dims,
+ errors::InvalidArgument("axis = ", axis_, " not in [",
+ -expanded_num_dims, ", ",
+ expanded_num_dims, ")"));
+
+ std::vector<xla::ComputationDataHandle> reshaped_inputs(num);
+
+ TensorShape child_shape(shapes[0]);
+ child_shape.InsertDim(axis, 1);
+
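+ // For example, packing three [2, 3] tensors along axis = 1 reshapes each
+ // to [2, 1, 3] and concatenates them into a [2, 3, 3] result.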
+ for (int i = 0; i < num; ++i) {
+ // Reshape the inputs to have an extra dimension of size 1.
+ reshaped_inputs[i] =
+ ctx->builder()->Reshape(values[i], child_shape.dim_sizes());
+ }
+
+ ctx->SetOutput(0, ctx->builder()->ConcatInDim(reshaped_inputs, axis));
+ }
+
+ private:
+ int axis_;
+};
+
+REGISTER_XLA_OP("Pack", PackOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
new file mode 100644
index 0000000000..2846414c5e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+namespace {
+
+class PadOp : public XlaOpKernel {
+ public:
+ explicit PadOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape pad_shape = ctx->InputShape(1);
+ const int dims = input_shape.dims();
+ OP_REQUIRES(
+ ctx,
+ TensorShapeUtils::IsMatrix(pad_shape) && pad_shape.dim_size(1) == 2,
+ errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
+ pad_shape.DebugString()));
+ const int fixed_dims =
+ (allow_legacy_scalars() && dims == 0 && pad_shape.dim_size(0) == 1)
+ ? 1
+ : dims;
+ OP_REQUIRES(
+ ctx, fixed_dims == pad_shape.dim_size(0),
+ errors::InvalidArgument(
+ "The first dimension of paddings must be the rank of inputs",
+ pad_shape.DebugString(), " ", input_shape.DebugString()));
+
+ if (fixed_dims == 0) {
+ // Tensor is rank 0. Return it unchanged.
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ // Evaluate the 'padding' constant input, reshaping to a matrix.
+ xla::Literal pad_literal;
+ OP_REQUIRES_OK(
+ ctx, ctx->ConstantInputReshaped(1, {fixed_dims, 2}, &pad_literal));
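+    // For example, for a rank-2 input, pad_literal holds a 2x2 matrix such as
+    // {{1, 2}, {0, 3}}: pad 1 row before and 2 rows after dimension 0, and
+    // 0 columns before and 3 columns after dimension 1.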
+
+ xla::PaddingConfig config;
+ for (int i = 0; i < fixed_dims; ++i) {
+ auto* dim = config.add_dimensions();
+ int before = xla::LiteralUtil::Get<int32>(pad_literal, {i, 0});
+ int after = xla::LiteralUtil::Get<int32>(pad_literal, {i, 1});
+ OP_REQUIRES(ctx, before >= 0 && after >= 0,
+ errors::InvalidArgument("Paddings must be non-negative: ",
+ before, " ", after));
+ dim->set_edge_padding_low(before);
+ dim->set_edge_padding_high(after);
+ }
+
+ auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
+ ctx->SetOutput(0, ctx->builder()->Pad(ctx->Input(0), zero, config));
+ }
+};
+
+REGISTER_XLA_OP("Pad", PadOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
new file mode 100644
index 0000000000..7a1ce2db85
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc
@@ -0,0 +1,374 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA specific pooling ops.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/pooling_ops_common.h"
+
+namespace tensorflow {
+namespace {
+
+// Superclass of pooling ops.
+class PoolingOp : public XlaOpKernel {
+ public:
+ explicit PoolingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ // Data format doesn't matter since the kernel is specified explicitly.
+ std::vector<int32> ksize_int;
+ std::vector<int32> stride_int;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
+ OP_REQUIRES(ctx, ksize_int.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
+ OP_REQUIRES(ctx, stride_int.size() == 4,
+ errors::InvalidArgument("Sliding window stride field must "
+ "specify 4 dimensions"));
+ for (int i = 0; i < 4; ++i) {
+ ksize_.push_back(ksize_int[i]);
+ stride_.push_back(stride_int[i]);
+ }
+ Padding padding;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
+ padding_ = (padding == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+ }
+
+ // Method that builds an initial value to use in reductions.
+ virtual xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
+ DataType data_type) = 0;
+
+ // The reduction operation to apply to each window.
+ virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx,
+ DataType dtype) = 0;
+
+ // A post-processing operation to apply on the outputs of the ReduceWindow.
+ virtual xla::ComputationDataHandle PostProcessOutput(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
+ DataType dtype, const TensorShape& input_shape) = 0;
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ const DataType type = input_type(0);
+ xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
+ input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_,
+ stride_, padding_);
+ ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
+ }
+
+ protected:
+ std::vector<int64> ksize_;
+ std::vector<int64> stride_;
+ xla::Padding padding_;
+};
+
+class MaxPoolOp : public PoolingOp {
+ public:
+ explicit MaxPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) {}
+
+ xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
+ DataType data_type) override {
+ return XlaHelpers::MinValue(b, data_type);
+ }
+
+ const xla::Computation* Reduction(XlaOpKernelContext* ctx,
+ DataType dtype) override {
+ return ctx->GetOrCreateMax(dtype);
+ }
+
+ xla::ComputationDataHandle PostProcessOutput(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
+ DataType dtype, const TensorShape& input_shape) override {
+ return output;
+ }
+};
+
+REGISTER_XLA_OP("MaxPool", MaxPoolOp);
+
+// Common computation shared between AvgPool and AvgPoolGrad. Divide each
+// element of an image by the count of elements that contributed to that
+// element during pooling.
+static xla::ComputationDataHandle AvgPoolDivideByCount(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
+ DataType dtype, const TensorShape& input_shape, xla::Padding padding,
+ const std::vector<int64>& ksize, const std::vector<int64>& stride,
+ TensorFormat data_format) {
+ if (padding == xla::Padding::kValid) {
+ // In VALID padding, all windows have the same number of elements
+ // contributing to each average. Divide by the window size everywhere to
+ // get the average.
+ int64 window_size = std::accumulate(ksize.begin(), ksize.end(), 1,
+ [](int64 a, int64 b) { return a * b; });
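+    // For example, ksize = {1, 3, 3, 1} gives window_size = 9, so every
+    // output element is the window sum divided by 9.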
+
+ auto divisor =
+ XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
+ return ctx->builder()->Div(output, divisor);
+ } else {
+ // For SAME padding, the padding shouldn't be included in the
+ // counts. We use another ReduceWindow to find the right counts.
+
+ // TODO(phawkins): use a less brute-force way to compute this. Only
+ // the boundary regions will have interesting values here.
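+    // For example, with a 3x3 kernel and SAME padding, a window centered at
+    // an image corner only overlaps 4 valid input elements, so that output
+    // position is divided by 4 instead of 9.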
+
+ int height_dim = GetTensorDimIndex(data_format, 'H');
+ int width_dim = GetTensorDimIndex(data_format, 'W');
+ CHECK_LT(height_dim, width_dim);
+
+ // Build a matrix of all 1s, with the same width/height as the input.
+ auto ones = ctx->builder()->Broadcast(
+ XlaHelpers::One(ctx->builder(), dtype),
+ {input_shape.dim_size(height_dim), input_shape.dim_size(width_dim)});
+
+ // Perform a ReduceWindow with the same window size, strides, and padding
+ // to count the number of contributions to each result element.
+ auto counts = ctx->builder()->ReduceWindow(
+ ones, XlaHelpers::Zero(ctx->builder(), dtype),
+ *ctx->GetOrCreateAdd(dtype), {ksize[height_dim], ksize[width_dim]},
+ {stride[height_dim], stride[width_dim]}, xla::Padding::kSame);
+
+ return ctx->builder()->Div(output, counts, {height_dim, width_dim});
+ }
+}
+
+class AvgPoolOp : public PoolingOp {
+ public:
+ explicit AvgPoolOp(OpKernelConstruction* ctx) : PoolingOp(ctx) {
+ string data_format_str;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
+ OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ }
+
+ xla::ComputationDataHandle InitValue(xla::ComputationBuilder* b,
+ DataType data_type) override {
+ return XlaHelpers::Zero(b, data_type);
+ }
+
+ const xla::Computation* Reduction(XlaOpKernelContext* ctx,
+ DataType dtype) override {
+ return ctx->GetOrCreateAdd(dtype);
+ }
+
+ xla::ComputationDataHandle PostProcessOutput(
+ XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
+ DataType dtype, const TensorShape& input_shape) override {
+ return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
+ ksize_, stride_, data_format_);
+ }
+
+ private:
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("AvgPool", AvgPoolOp);
+
+// The operation to compute MaxPool gradients.
+// It takes three inputs:
+// - The original input tensor
+// - The original output tensor
+// - Backprop tensor for output
+// It produces one output: backprop tensor for input.
+class MaxPoolGradOp : public XlaOpKernel {
+ public:
+ explicit MaxPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(ctx, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
+ OP_REQUIRES(ctx, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape tensor_in_shape = ctx->InputShape(0);
+ const TensorShape tensor_out_shape = ctx->InputShape(1);
+ const TensorShape out_backprop_shape = ctx->InputShape(2);
+
+ // For maxpooling, tensor_in should have 4 dimensions.
+ OP_REQUIRES(ctx, tensor_in_shape.dims() == 4,
+ errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ OP_REQUIRES(ctx, tensor_out_shape.dims() == 4,
+ errors::InvalidArgument("tensor_out must be 4-dimensional"));
+ // For maxpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(ctx, out_backprop_shape.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+
+ // TODO(phawkins): The XLA version doesn't need tensor_out. Investigate
+ // whether this is a good time/space tradeoff.
+ auto input = ctx->Input(0);
+ auto out_backprop = ctx->Input(2);
+
+ xla::Padding xla_padding =
+ (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ xla::PrimitiveType element_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
+ xla::ComputationDataHandle init_value =
+ XlaHelpers::Zero(ctx->builder(), input_type(2));
+ auto select = CreateScalarGeComputation(element_type, ctx->builder());
+ auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
+ xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
+ input, select, ksize_, stride_, xla_padding, out_backprop, init_value,
+ scatter);
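+    // SelectAndScatter recomputes each pooling window over `input`, uses the
+    // >= comparison to pick the position of the window maximum, and routes
+    // the corresponding out_backprop value to that position, adding values
+    // that land on the same position.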
+
+ ctx->SetOutput(0, gradients);
+ }
+
+ private:
+ std::vector<int64> ksize_;
+ std::vector<int64> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("MaxPoolGrad", MaxPoolGradOp);
+
+// Average-pooling gradient
+class AvgPoolGradOp : public XlaOpKernel {
+ public:
+ explicit AvgPoolGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ string data_format;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
+ OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
+ OP_REQUIRES(ctx, ksize_.size() == 4,
+ errors::InvalidArgument("Sliding window ksize field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
+ OP_REQUIRES(ctx, stride_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
+ OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
+ errors::Unimplemented(
+ "Pooling is not yet supported on the batch dimension."));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape gradients_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));
+
+ const TensorShape out_backprop_shape = ctx->InputShape(1);
+
+    // For avgpooling, orig_input_shape (input 0) should be a 1-D tensor with
+    // 4 elements, so the resolved gradients_shape must be 4-dimensional.
+ OP_REQUIRES(
+ ctx, gradients_shape.dims() == 4,
+ errors::InvalidArgument("orig_input_shape must have 4 elements"));
+
+ // For avgpooling, out_backprop should have 4 dimensions.
+ OP_REQUIRES(ctx, out_backprop_shape.dims() == 4,
+ errors::InvalidArgument("out_backprop must be 4-dimensional"));
+
+ int height_dim = GetTensorDimIndex(data_format_, 'H');
+ int width_dim = GetTensorDimIndex(data_format_, 'W');
+ int depth = GetTensorDim(out_backprop_shape, data_format_, 'C');
+
+    // We can think of average-pooling as:
+    // * a convolution whose kernel is 1 wherever the input feature and
+    //   output feature are equal, and 0 everywhere else,
+    // * followed by dividing by the counts.
+ //
+ // This then gives us an algorithm to build the gradient:
+ // * divide out_backprop by the counts, followed by
+ // * Conv2DBackpropInput specialized for that kernel, which simplifies to
+ // a Pad and a ReduceWindow.
+ //
+ // For an explanation of backpropagation for convolution, see the comments
+ // in third_party/tensorflow/core/kernels/conv_grad_ops.h
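+    //
+    // For example, with ksize = {1, 2, 2, 1}, stride = {1, 2, 2, 1} and VALID
+    // padding, each out_backprop element is divided by 4 and that value is
+    // routed back to each of the 4 input positions in the window that
+    // produced it.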
+
+ // TF filter shape is [ H, W, inC, outC ]
+ TensorShape filter_shape(
+ {ksize_[height_dim], ksize_[width_dim], depth, depth});
+
+ // Reuse the logic from Conv2DBackpropInput to compute padding.
+ Conv2DBackpropDimensions dims;
+ OP_REQUIRES_OK(
+ ctx, Conv2DBackpropComputeDimensions(
+ "AvgPoolGrad", gradients_shape, filter_shape,
+ out_backprop_shape, stride_, padding_, data_format_, &dims));
+
+ auto out_backprop = ctx->Input(1);
+
+    // The input gradients are computed by a convolution of the output
+    // gradients and the filter, with some appropriate padding. See the
+    // comment at the top of conv_grad_ops.h for details.
+ DataType dtype = input_type(1);
+
+ xla::Padding xla_padding =
+ (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
+
+ // Divide the out_backprop values by the counts for each spatial position.
+ std::vector<int64> stride_int64s(stride_.begin(), stride_.end());
+ auto out_backprop_div =
+ AvgPoolDivideByCount(ctx, out_backprop, dtype, gradients_shape,
+ xla_padding, ksize_, stride_int64s, data_format_);
+
+ // Pad the gradients in the spatial dimensions. We use the same padding
+ // as Conv2DBackpropInput.
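+    // Interior padding of (stride - 1) inserts zeros between adjacent
+    // gradient values, undoing the stride of the forward pooling before the
+    // summing ReduceWindow below.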
+ xla::PaddingConfig padding_config = xla::MakeNoPaddingConfig(4);
+ auto* row_padding = padding_config.mutable_dimensions(height_dim);
+ row_padding->set_edge_padding_low(dims.rows.pad_before);
+ row_padding->set_edge_padding_high(dims.rows.pad_after);
+ row_padding->set_interior_padding(dims.rows.stride - 1);
+
+ auto* col_padding = padding_config.mutable_dimensions(width_dim);
+ col_padding->set_edge_padding_low(dims.cols.pad_before);
+ col_padding->set_edge_padding_high(dims.cols.pad_after);
+ col_padding->set_interior_padding(dims.cols.stride - 1);
+
+ auto zero = XlaHelpers::Zero(ctx->builder(), dtype);
+ auto padded_gradients =
+ ctx->builder()->Pad(out_backprop_div, zero, padding_config);
+
+ // in_backprop = padded_gradients <conv> ones
+ xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow(
+ padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_,
+ /* window_strides = */ {1, 1, 1, 1}, xla::Padding::kValid);
+
+ ctx->SetOutput(0, in_backprop);
+ }
+
+ private:
+ std::vector<int64> ksize_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+};
+
+REGISTER_XLA_OP("AvgPoolGrad", AvgPoolGradOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
new file mode 100644
index 0000000000..4ffe278d1c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA implementations of Random ops
+// TODO(misard,phawkins): handle random number generator seeds/states correctly.
+// TODO(misard,phawkins): add tests.
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class RandomUniformOp : public XlaOpKernel {
+ public:
+ explicit RandomUniformOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+
+ const DataType dtype = output_type(0);
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
+
+ xla::ComputationBuilder* b = ctx->builder();
+ xla::ComputationDataHandle result = b->RngUniform(
+ XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp);
+};
+
+REGISTER_XLA_OP("RandomUniform", RandomUniformOp);
+
+class RandomUniformIntOp : public XlaOpKernel {
+ public:
+ explicit RandomUniformIntOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx,
+ TensorShapeToXLAShape(input_type(1), shape, &xla_shape));
+
+ const TensorShape minval_shape = ctx->InputShape(1);
+ const TensorShape maxval_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(minval_shape),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(maxval_shape),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval_shape.DebugString()));
+
+ auto minval = ctx->Input(1);
+ auto maxval = ctx->Input(2);
+ ctx->SetOutput(0, ctx->builder()->RngUniform(minval, maxval, xla_shape));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp);
+};
+
+REGISTER_XLA_OP("RandomUniformInt", RandomUniformIntOp);
+
+class RandomStandardNormalOp : public XlaOpKernel {
+ public:
+ explicit RandomStandardNormalOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const DataType dtype = output_type(0);
+
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape));
+ xla::Shape xla_shape;
+ OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape));
+
+ xla::ComputationBuilder* b = ctx->builder();
+
+ // Normal distribution with a mean of 0 and a standard deviation of 1:
+ xla::ComputationDataHandle result = b->RngNormal(
+ XlaHelpers::Zero(b, dtype), XlaHelpers::One(b, dtype), xla_shape);
+
+ ctx->SetOutput(0, result);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp);
+};
+
+REGISTER_XLA_OP("RandomStandardNormal", RandomStandardNormalOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
new file mode 100644
index 0000000000..ac929af2e2
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc
@@ -0,0 +1,157 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific reduction Ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+class SumOp : public XlaReductionOp {
+ public:
+ explicit SumOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->Add(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("Sum", SumOp);
+
+class ProdOp : public XlaReductionOp {
+ public:
+ explicit ProdOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder) override {
+ return XlaHelpers::One(builder, input_type(0));
+ }
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->Mul(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("Prod", ProdOp);
+
+class MinOp : public XlaReductionOp {
+ public:
+ explicit MinOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder) override {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
+ return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
+ }
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->Min(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("Min", MinOp);
+
+class MaxOp : public XlaReductionOp {
+ public:
+ explicit MaxOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder) override {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
+ return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type));
+ }
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->Max(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("Max", MaxOp);
+
+class MeanOp : public XlaReductionOp {
+ public:
+ explicit MeanOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->Add(scalar_lhs, scalar_rhs);
+ }
+
+ bool BuildFinalizer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_argument,
+ int64 num_elements_reduced) override {
+ auto divisor = XlaHelpers::IntegerLiteral(builder, input_type(0),
+ num_elements_reduced);
+ builder->Div(scalar_argument, divisor);
+ return true;
+ }
+};
+
+REGISTER_XLA_OP("Mean", MeanOp);
+
+class AllOp : public XlaReductionOp {
+ public:
+ explicit AllOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder) override {
+ return builder->ConstantR0<bool>(true);
+ }
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->LogicalAnd(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("All", AllOp);
+
+class AnyOp : public XlaReductionOp {
+ public:
+ explicit AnyOp(OpKernelConstruction* ctx) : XlaReductionOp(ctx) {}
+
+ xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder) override {
+ return builder->ConstantR0<bool>(false);
+ }
+
+ void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) override {
+ builder->LogicalOr(scalar_lhs, scalar_rhs);
+ }
+};
+
+REGISTER_XLA_OP("Any", AnyOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.h b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
new file mode 100644
index 0000000000..7f0dd26f91
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.h
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific base classes for Reduction Ops.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+// Reduction operations. The base class contains virtual methods
+// to override: InitialValue constructs the base case for the
+// reduction; BuildReducer adds the implementation of the reduction
+// lambda to a xla::ComputationBuilder; and BuildFinalizer adds the
+// implementation of the finalizer lambda (if there is one) to a
+// xla::ComputationBuilder.
+class XlaReductionOp : public XlaOpKernel {
+ public:
+ explicit XlaReductionOp(OpKernelConstruction* ctx);
+ ~XlaReductionOp() override {}
+
+ // Return the base case for the reduction. Defaults to zero.
+ virtual xla::ComputationDataHandle InitialValue(
+ xla::ComputationBuilder* builder);
+
+ // Implement the (scalar,scalar)->scalar lambda that should be
+ // applied to each pair of elements to be reduced. The desired
+ // computation should be added to 'builder' and
+ // '(scalar_lhs,scalar_rhs)' are the function's inputs.
+ virtual void BuildReducer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_lhs,
+ const xla::ComputationDataHandle& scalar_rhs) = 0;
+
+  // Implement the scalar->scalar lambda that should be applied to
+  // each element to be finalized. The desired computation should be
+  // added to 'builder', and 'scalar_argument' is the function's
+  // input. 'num_elements_reduced' is the number of elements that contributed
+  // to the reduction. Return true if the reduction has a finalizer;
+  // otherwise return false, and any computation added to 'builder' will be
+  // ignored. The default implementation returns false.
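+  // For example, Mean overrides BuildFinalizer to divide the reduced sum by
+  // num_elements_reduced, while Sum relies on the default, which returns
+  // false.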
+ virtual bool BuildFinalizer(xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_argument,
+ int64 num_elements_reduced);
+
+ void Compile(XlaOpKernelContext* ctx) override;
+
+ private:
+ // True if the number of dimensions should be maintained.
+ bool keep_dims_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_REDUCTION_OPS_H_
diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
new file mode 100644
index 0000000000..d6b085e897
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific reduction Ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/reduction_ops.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+
+XlaReductionOp::XlaReductionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ const DataType dt = BaseType(input_type(0));
+ OP_REQUIRES_OK(ctx, ctx->MatchSignature({dt, DT_INT32}, {dt}));
+
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("keep_dims", &keep_dims_));
+}
+
+// Return the base case for the reduction. Defaults to zero.
+xla::ComputationDataHandle XlaReductionOp::InitialValue(
+ xla::ComputationBuilder* builder) {
+ return XlaHelpers::Zero(builder, input_type(0));
+}
+
+// Unless BuildFinalizer is overridden the reduction has no
+// finalizer.
+bool XlaReductionOp::BuildFinalizer(
+ xla::ComputationBuilder* builder,
+ const xla::ComputationDataHandle& scalar_argument,
+ int64 num_elements_reduced) {
+ return false;
+}
+
+void XlaReductionOp::Compile(XlaOpKernelContext* ctx) {
+ const TensorShape data_shape = ctx->InputShape(0);
+ const TensorShape axes_tensor_shape = ctx->InputShape(1);
+ VLOG(1) << "ReductionOp: " << ctx->op_kernel().name();
+
+ if (axes_tensor_shape.num_elements() == 0) {
+    // The reduction axes tensor is empty, which means there are no
+    // axes to reduce, so just pass the input directly through to the
+    // output.
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ // Evaluate the constant, reshaping to a 1-vector if it is a scalar.
+ xla::Literal axes_literal;
+ OP_REQUIRES_OK(ctx,
+ ctx->ConstantInputReshaped(
+ 1, {axes_tensor_shape.num_elements()}, &axes_literal));
+
+ VLOG(1) << "data shape: " << data_shape.DebugString();
+ VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal);
+
+ gtl::InlinedVector<bool, 4> bitmap(data_shape.dims(), false);
+ std::vector<int64> xla_axes;
+ int64 num_elements_reduced = 1LL;
+ for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) {
+ int32 index = xla::LiteralUtil::Get<int>(axes_literal, {i});
+ OP_REQUIRES(ctx,
+ !(index < -data_shape.dims() || index >= data_shape.dims()),
+ errors::InvalidArgument("Invalid reduction dimension (", index,
+ " for input with ", data_shape.dims(),
+ " dimension(s)"));
+ index = (index + data_shape.dims()) % data_shape.dims();
+ bitmap[index] = true;
+ xla_axes.push_back(index);
+ num_elements_reduced *= data_shape.dim_size(index);
+ }
+
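+  // For example, reducing a [2, 3, 5] tensor over axes {0, 2} gives
+  // num_elements_reduced = 10, and the final_shape computed below is {3}
+  // (or {1, 3, 1} with keep_dims).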
+ std::vector<int64> final_shape;
+ for (int i = 0; i < data_shape.dims(); ++i) {
+ if (!bitmap[i]) {
+ // If we are not reducing along dimension i.
+ int64 dim = data_shape.dim_size(i);
+ final_shape.push_back(dim);
+ } else if (keep_dims_) {
+ // We are reducing along dimension i, but we want to keep the
+ // same number of dimensions, so we set the dimension of i to
+ // '1'.
+ final_shape.push_back(1);
+ }
+ }
+
+ string desc = ctx->op_kernel().name();
+
+ // Call virtual method to get the initial value.
+ const xla::ComputationDataHandle initial = InitialValue(ctx->builder());
+ // Construct the builder for the reduction lambda.
+ xla::ComputationBuilder r(ctx->builder()->client(),
+ strings::StrCat(desc, "-reduction"));
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type));
+ // Make two scalar parameters of the desired type for the lambda.
+ xla::ComputationDataHandle rx =
+ r.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ xla::ComputationDataHandle ry =
+ r.Parameter(1, xla::ShapeUtil::MakeShape(type, {}), "y");
+
+ auto data = ctx->Input(0);
+
+ // Call virtual method to build the reduction lambda.
+ BuildReducer(&r, rx, ry);
+ xla::Computation reduction_computation = r.Build().ConsumeValueOrDie();
+ xla::ComputationDataHandle reduce =
+ ctx->builder()->Reduce(data, initial, reduction_computation, xla_axes);
+
+ // Construct the builder for the finalizer lambda.
+ xla::ComputationBuilder f(ctx->builder()->client(),
+ strings::StrCat(desc, "-finalizer"));
+ // Make the scalar parameter of the desired type for the lambda.
+ xla::ComputationDataHandle fx =
+ f.Parameter(0, xla::ShapeUtil::MakeShape(type, {}), "x");
+ // Call virtual method to build the finalizer lambda.
+ bool has_finalizer = BuildFinalizer(&f, fx, num_elements_reduced);
+ xla::Computation finalizer_computation = f.Build().ConsumeValueOrDie();
+ xla::ComputationDataHandle pre_reshaped_data;
+ if (has_finalizer) {
+ // This reduction Op includes a finalizer so run it as a Map.
+ pre_reshaped_data = ctx->builder()->Map({reduce}, finalizer_computation);
+ } else {
+ pre_reshaped_data = reduce;
+ }
+
+ xla::ComputationDataHandle result;
+ if (keep_dims_) {
+ result = ctx->builder()->Reshape(pre_reshaped_data, final_shape);
+ } else {
+ result = pre_reshaped_data;
+ }
+ ctx->SetOutput(0, result);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
new file mode 100644
index 0000000000..3cddff9df4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of the Relu ops.
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/no_op.h"
+
+namespace tensorflow {
+namespace {
+
+class ReluOp : public XlaOpKernel {
+ public:
+ explicit ReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Computes the elementwise max of the input and 0.
+  void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+ auto zero = XlaHelpers::Zero(builder, input_type(0));
+ ctx->SetOutput(0, builder->Max(zero, ctx->Input(0)));
+ }
+};
+
+class Relu6Op : public XlaOpKernel {
+ public:
+ explicit Relu6Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+  // Clamps each element of the input between 0 and 6.
+  void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+ auto zero = XlaHelpers::Zero(builder, input_type(0));
+ auto six = XlaHelpers::IntegerLiteral(builder, input_type(0), 6);
+ ctx->SetOutput(0, builder->Clamp(zero, ctx->Input(0), six));
+ }
+};
+
+// A subclass of XlaBinaryMapOp must build the lambda computation
+// that describes the (scalar,scalar)->scalar function to apply to
+// each element of the input. We have to use XlaBinaryMapOp instead of
+// XlaBinaryOp here because XLA Select does not do automatic
+// broadcasting.
+class ReluGradOp : public XlaBinaryMapOp {
+ public:
+ explicit ReluGradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+ // Return the lhs (incoming gradient) if the rhs (input feature) > 0,
+ // otherwise return 0.
+ void BuildMapLambda(xla::ComputationBuilder* b,
+ const xla::ComputationDataHandle& gradient,
+ const xla::ComputationDataHandle& feature) override {
+ const auto zero = XlaHelpers::Zero(b, input_type(0));
+ b->Select(b->Gt(feature, zero), gradient, zero);
+ }
+};
+
+class Relu6GradOp : public XlaBinaryMapOp {
+ public:
+ explicit Relu6GradOp(OpKernelConstruction* ctx) : XlaBinaryMapOp(ctx) {}
+  // Return the lhs (incoming gradient) if 0 < rhs (input feature) < 6,
+  // otherwise return 0.
+ void BuildMapLambda(xla::ComputationBuilder* b,
+ const xla::ComputationDataHandle& gradient,
+ const xla::ComputationDataHandle& feature) override {
+ const auto zero = XlaHelpers::Zero(b, input_type(0));
+ auto six = XlaHelpers::IntegerLiteral(b, input_type(0), 6);
+ b->Select(b->LogicalAnd(b->Lt(feature, six), b->Gt(feature, zero)),
+ gradient, zero);
+ }
+};
+
+REGISTER_XLA_OP("Relu", ReluOp);
+REGISTER_XLA_OP("Relu6", Relu6Op);
+REGISTER_XLA_OP("ReluGrad", ReluGradOp);
+REGISTER_XLA_OP("Relu6Grad", Relu6GradOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
new file mode 100644
index 0000000000..febce0e126
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific reshape Op.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace {
+
+class ReshapeOp : public XlaOpKernel {
+ public:
+ explicit ReshapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape sizes_shape = ctx->InputShape(1);
+ // Preliminary validation of sizes.
+ OP_REQUIRES(ctx, IsLegacyVector(sizes_shape),
+ errors::InvalidArgument("sizes input must be 1-D, not shape ",
+ sizes_shape.DebugString()));
+ const int64 num_dims = sizes_shape.num_elements();
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+
+    // Compute the output shape. Determine the product of the specified
+    // dimensions, and find the index of the unspecified one, if there
+    // is one.
+ TensorShape shape;
+ int64 product = 1;
+ int unknown_index = -1;
+ for (int d = 0; d < num_dims; ++d) {
+ const int32 size = xla::LiteralUtil::Get<int>(literal, {d});
+ if (size == -1) {
+ OP_REQUIRES(
+ ctx, unknown_index == -1,
+ errors::InvalidArgument("only one input size may be -1, not both ",
+ unknown_index, " and ", d));
+ unknown_index = d;
+ shape.AddDim(1);
+ } else {
+ OP_REQUIRES(ctx, size >= 0,
+ errors::InvalidArgument(
+ "size ", d, " must be non-negative, not ", size));
+ shape.AddDim(size);
+ product *= size;
+ }
+ }
+ if (unknown_index != -1) {
+ OP_REQUIRES(
+ ctx, product > 0,
+ errors::InvalidArgument("Reshape cannot infer the missing input size "
+ "for an empty tensor unless all specified "
+ "input sizes are non-zero"));
+ const int64 missing = input_shape.num_elements() / product;
+ OP_REQUIRES(
+ ctx, product * missing == input_shape.num_elements(),
+ errors::InvalidArgument(
+ "Input to reshape is a tensor with ", input_shape.num_elements(),
+ " values, but the requested shape requires a multiple of ",
+ product));
+ shape.set_dim(unknown_index, missing);
+ }
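+    // For example, reshaping a 12-element tensor with sizes = {3, -1} gives
+    // product = 3, missing = 4, and an output shape of {3, 4}.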
+ OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
+ errors::InvalidArgument("Input to reshape is a tensor with ",
+ input_shape.num_elements(),
+ " values, but the requested shape has ",
+ shape.num_elements()));
+
+ VLOG(1) << "Reshape " << input_shape.DebugString() << " "
+ << shape.DebugString();
+
+ ctx->SetOutput(0,
+ ctx->builder()->Reshape(ctx->Input(0), shape.dim_sizes()));
+ }
+};
+
+REGISTER_XLA_OP("Reshape", ReshapeOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
new file mode 100644
index 0000000000..87d11a38d4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -0,0 +1,79 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+// This TensorFlow op indicates that its input should be treated as a
+// specific return value from a function.
+class RetvalOp : public XlaOpKernel {
+ public:
+ explicit RetvalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const Tensor& input = ctx->op_kernel_context()->input(0);
+
+ OP_REQUIRES(ctx, input.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(input.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ auto frame = ctx->call_frame();
+ if (frame) {
+ // If 'frame' is non-null, this is an inner function call inside a JIT
+ // compilation.
+ frame->SetRetval(index_, input);
+ } else {
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ auto is_constant = ctx->builder()->IsConstant(input);
+ if (!is_constant.ok()) {
+ ctx->SetStatus(is_constant.status());
+ return;
+ }
+
+ XlaContext& tc = XlaContext::Get(ctx);
+ if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
+ tc.AddConstRetval(index_, dtype_, literal);
+ } else {
+ tc.AddRetval(index_, input);
+ }
+ }
+ }
+
+ private:
+ // The index of this return value in the returned tuple.
+ int index_;
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
+};
+
+REGISTER_XLA_OP("_Retval", RetvalOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/select_op.cc b/tensorflow/compiler/tf2xla/kernels/select_op.cc
new file mode 100644
index 0000000000..0fecc338ca
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/select_op.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+namespace {
+
+class SelectOp : public XlaOpKernel {
+ public:
+ explicit SelectOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape cond_shape = ctx->InputShape(0);
+ const TensorShape then_shape = ctx->InputShape(1);
+ const TensorShape else_shape = ctx->InputShape(2);
+
+ OP_REQUIRES(
+ ctx, then_shape.IsSameSize(else_shape),
+ errors::InvalidArgument(
+ "'then' and 'else' must have the same size. but received: ",
+ then_shape.DebugString(), " vs. ", else_shape.DebugString()));
+
+ bool broadcasting = !cond_shape.IsSameSize(then_shape);
+ if (broadcasting) {
+ OP_REQUIRES(
+ ctx, TensorShapeUtils::IsVector(cond_shape),
+ errors::InvalidArgument("'cond' must be a vector, but saw shape: ",
+ cond_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then_shape),
+ errors::InvalidArgument(
+ "'then' must be at least a vector, but saw shape: ",
+ then_shape.DebugString()));
+ OP_REQUIRES(ctx, then_shape.dim_size(0) == cond_shape.num_elements(),
+ errors::InvalidArgument("Number of batches of 'then' must "
+ "match size of 'cond', but saw: ",
+ then_shape.dim_size(0), " vs. ",
+ cond_shape.num_elements()));
+ }
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ auto cond_handle = ctx->Input(0);
+ auto then_handle = ctx->Input(1);
+ auto else_handle = ctx->Input(2);
+
+ if (broadcasting) {
+ // TODO(phawkins): broadcasting on the right seems pretty awkward in
+ // XLA. It seems we have to broadcast on the left and then Reshape
+ // to get the dimensions in the right order.
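+      // For example, cond of shape [2] with 'then' of shape [2, 3]:
+      // broadcasting on the left gives shape [3, 2], and transposing with
+      // dim_order = {1, 0} yields the desired [2, 3].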
+ const auto dim_sizes = then_shape.dim_sizes();
+ gtl::ArraySlice<int64> bdims = dim_sizes;
+ bdims.pop_front();
+ cond_handle = builder->Broadcast(cond_handle, bdims);
+
+ std::vector<int64> dim_order(then_shape.dims());
+ dim_order[0] = then_shape.dims() - 1;
+ std::iota(dim_order.begin() + 1, dim_order.end(), 0);
+ cond_handle = builder->Transpose(cond_handle, dim_order);
+ }
+ ctx->SetOutput(0, builder->Select(cond_handle, then_handle, else_handle));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(SelectOp);
+};
+
+REGISTER_XLA_OP("Select", SelectOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
new file mode 100644
index 0000000000..42ae978c3c
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc
@@ -0,0 +1,213 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific sequence and range Ops.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+
+namespace tensorflow {
+namespace {
+
+template <typename T>
+Status GetValue(int index, XlaOpKernelContext* ctx, T* value) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
+ *value = xla::LiteralUtil::Get<T>(literal, {});
+ return Status::OK();
+}
+
+Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal));
+ switch (literal.shape().element_type()) {
+ case xla::S32:
+ *value = xla::LiteralUtil::Get<int32>(literal, {});
+ break;
+ case xla::S64:
+ *value = xla::LiteralUtil::Get<int64>(literal, {});
+ break;
+ default:
+ return errors::InvalidArgument("Invalid argument type for argument",
+ index);
+ }
+ return Status::OK();
+}
+
+// The type-specific part of the implementation of Range.
+template <typename T>
+Status CreateRangeTensor(const xla::Literal& start_literal,
+ const xla::Literal& limit_literal,
+ const xla::Literal& delta_literal, Tensor* output) {
+ T start = xla::LiteralUtil::Get<T>(start_literal, {});
+ T limit = xla::LiteralUtil::Get<T>(limit_literal, {});
+ T delta = xla::LiteralUtil::Get<T>(delta_literal, {});
+
+ if (delta == 0) {
+ return errors::InvalidArgument("Requires delta != 0: ", delta);
+ }
+ if (delta > 0) {
+ if (start > limit) {
+ return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+ start, "/", limit);
+ }
+ } else {
+ if (start < limit) {
+ return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+ start, "/", limit);
+ }
+ }
+ int64 size =
+ (std::is_integral<T>::value
+ ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
+ : std::ceil(std::abs((limit - start) / delta)));
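+  // For example, start = 0, limit = 10, delta = 3 gives
+  // size = (10 + 3 - 1) / 3 = 4, producing {0, 3, 6, 9}.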
+
+ *output = Tensor(DataTypeToEnum<T>::v(), TensorShape({size}));
+ auto flat = output->flat<T>();
+ T val = start;
+ for (int64 i = 0; i < size; ++i) {
+ flat(i) = val;
+ val += delta;
+ }
+ return Status::OK();
+}
+
+class RangeOp : public XlaOpKernel {
+ public:
+ explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape start_in_shape = ctx->InputShape(0);
+ const TensorShape limit_in_shape = ctx->InputShape(1);
+ const TensorShape delta_in_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, IsLegacyScalar(start_in_shape),
+ errors::InvalidArgument("start must be a scalar, not shape ",
+ start_in_shape.DebugString()));
+ OP_REQUIRES(ctx, IsLegacyScalar(limit_in_shape),
+ errors::InvalidArgument("limit must be a scalar, not shape ",
+ limit_in_shape.DebugString()));
+ OP_REQUIRES(ctx, IsLegacyScalar(delta_in_shape),
+ errors::InvalidArgument("delta must be a scalar, not shape ",
+ delta_in_shape.DebugString()));
+ xla::Literal start, limit, delta;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
+
+ DataType type = input_type(0);
+ Tensor output;
+ Status status;
+ switch (type) {
+ case DT_INT32:
+ status = CreateRangeTensor<int32>(start, limit, delta, &output);
+ break;
+ case DT_INT64:
+ status = CreateRangeTensor<int64>(start, limit, delta, &output);
+ break;
+ case DT_FLOAT:
+ status = CreateRangeTensor<float>(start, limit, delta, &output);
+ break;
+ case DT_DOUBLE:
+ status = CreateRangeTensor<double>(start, limit, delta, &output);
+ break;
+ default:
+ status = errors::InvalidArgument("Invalid type for Range ",
+ DataTypeString(type));
+ }
+ OP_REQUIRES_OK(ctx, status);
+ ctx->SetConstantOutput(0, output);
+ }
+};
+
+REGISTER_XLA_OP("Range", RangeOp);
+
+class LinSpaceOp : public XlaOpKernel {
+ public:
+ explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape start_in_shape = ctx->InputShape(0);
+ const TensorShape stop_in_shape = ctx->InputShape(1);
+ const TensorShape num_in_shape = ctx->InputShape(2);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
+ errors::InvalidArgument("start must be a scalar, not shape ",
+ start_in_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape),
+ errors::InvalidArgument("stop must be a scalar, not shape ",
+ stop_in_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape),
+ errors::InvalidArgument("num must be a scalar, not shape ",
+ num_in_shape.DebugString()));
+
+ DataType type = ctx->input_type(0);
+
+ int64 num;
+ OP_REQUIRES_OK(ctx, GetIntValue(2, ctx, &num));
+ OP_REQUIRES(ctx, num > 0,
+ errors::InvalidArgument("Requires num > 0: ", num));
+ Tensor out_constant(type, TensorShape({num}));
+
+ switch (type) {
+ case DT_FLOAT: {
+ float start, stop;
+ OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
+ OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
+ auto flat = out_constant.flat<float>();
+ if (num == 1) {
+ flat(0) = start;
+ } else {
+ const float step = (stop - start) / (num - 1);
+ for (int64 i = 0; i < num; ++i) {
+ flat(i) = start + step * i;
+ }
+ }
+ break;
+ }
+ case DT_DOUBLE: {
+ double start, stop;
+ OP_REQUIRES_OK(ctx, GetValue(0, ctx, &start));
+ OP_REQUIRES_OK(ctx, GetValue(1, ctx, &stop));
+ auto flat = out_constant.flat<double>();
+ if (num == 1) {
+ flat(0) = start;
+ } else {
+ const double step = (stop - start) / (num - 1);
+ for (int64 i = 0; i < num; ++i) {
+ flat(i) = start + step * i;
+ }
+ }
+ break;
+ }
+
+ default:
+ ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
+ DataTypeString(type)));
+ return;
+ }
+ ctx->SetConstantOutput(0, out_constant);
+ }
+};
+
+REGISTER_XLA_OP("LinSpace", LinSpaceOp);
+
+} // namespace
+} // namespace tensorflow
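
Both kernels above fold their outputs into compile-time constants; LinSpace in particular fills the tensor with start + step * i, where step = (stop - start) / (num - 1), and special-cases num == 1. A standalone sketch of that stepping arithmetic in plain C++ (illustrative only; the linspace helper below is not part of this change or of the TF/XLA API):

    #include <cstdio>
    #include <vector>

    // Illustrative sketch of the stepping logic used by the LinSpace kernel above.
    std::vector<double> linspace(double start, double stop, int num) {
      std::vector<double> out(num);
      if (num == 1) {
        out[0] = start;  // A single point is just the start value.
      } else {
        const double step = (stop - start) / (num - 1);
        for (int i = 0; i < num; ++i) out[i] = start + step * i;
      }
      return out;
    }

    int main() {
      for (double v : linspace(0.0, 1.0, 5)) std::printf("%g ", v);  // 0 0.25 0.5 0.75 1
      std::printf("\n");
    }
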
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
new file mode 100644
index 0000000000..e7eec1cefd
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -0,0 +1,245 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Shape Ops.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+namespace {
+
+class ShapeOp : public XlaOpKernel {
+ public:
+ explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const int rank = input_shape.dims();
+ Tensor shape_constant(DT_INT32, TensorShape({rank}));
+ auto vec = shape_constant.vec<int32>();
+ // TODO(dga): support int64. b/28119922.
+ for (int i = 0; i < rank; ++i) {
+ int64 dim_size = input_shape.dim_size(i);
+ OP_REQUIRES(
+ ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("Shape does not support tensors > int32max",
+ " but dim ", i, " is ", dim_size));
+ vec(i) = static_cast<int32>(dim_size);
+ }
+
+ ctx->SetConstantOutput(0, shape_constant);
+ }
+};
+
+REGISTER_XLA_OP("Shape", ShapeOp);
+
+class ShapeNOp : public XlaOpKernel {
+ public:
+ explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ const TensorShape shape = ctx->InputShape(i);
+ const int dims = shape.dims();
+ Tensor shape_constant(DT_INT32, TensorShape({dims}));
+ auto vec = shape_constant.vec<int32>();
+
+ // TODO(dga): support int64. b/28119922.
+ for (int j = 0; j < dims; ++j) {
+ int64 dim_size = shape.dim_size(j);
+ OP_REQUIRES(
+ ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("Shape does not support tensors > int32max",
+ " but shape ", i, " dim ", j, " is ",
+ dim_size));
+ vec(j) = static_cast<int32>(dim_size);
+ }
+
+ ctx->SetConstantOutput(i, shape_constant);
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+};
+REGISTER_XLA_OP("ShapeN", ShapeNOp);
+
+class RankOp : public XlaOpKernel {
+ public:
+ explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const int rank = input_shape.dims();
+ Tensor rank_constant(DT_INT32, TensorShape({}));
+ rank_constant.scalar<int32>()() = rank;
+
+ ctx->SetConstantOutput(0, rank_constant);
+ }
+};
+
+REGISTER_XLA_OP("Rank", RankOp);
+
+class SizeOp : public XlaOpKernel {
+ public:
+ explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const int64 size = input_shape.num_elements();
+ OP_REQUIRES(ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("Size does not work for tensors > "
+ "int32 max."));
+ Tensor size_constant(DT_INT32, TensorShape({}));
+ size_constant.scalar<int32>()() = static_cast<int32>(size);
+
+ ctx->SetConstantOutput(0, size_constant);
+ }
+};
+
+REGISTER_XLA_OP("Size", SizeOp);
+
+class ExpandDimsOp : public XlaOpKernel {
+ public:
+ explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape dim_shape = ctx->InputShape(1);
+
+ // TODO(phawkins): the standard implementation of ExpandDimsOp seems to
+ // accept legacy scalars, even when they should be forbidden by the graphdef
+ // version.
+ OP_REQUIRES(ctx, dim_shape.num_elements() == 1,
+ errors::InvalidArgument(strings::StrCat(
+ "dim input to ExpandDims must be a scalar; got ",
+ dim_shape.DebugString())));
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal));
+
+ int dim = literal.s32s(0);
+
+ OP_REQUIRES(ctx,
+ (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
+ errors::InvalidArgument("Tried to expand dim index ", dim,
+ " for tensor with ", input_shape.dims(),
+ " dimensions."));
+
+ auto existing_dims = input_shape.dim_sizes();
+ // Safe - # elements in tensor dims bounded.
+ const int existing_dims_size = static_cast<int>(existing_dims.size());
+ std::vector<int64> new_shape(existing_dims_size);
+ for (size_t i = 0; i < new_shape.size(); ++i) {
+ new_shape[i] = existing_dims[i];
+ }
+
+    // We emulate numpy's interpretation of the dim axis when
+    // -(input.dims() + 1) <= dim <= input.dims().
+ if (dim < 0) {
+ dim += existing_dims.size() + 1;
+ }
+
+ // Clamp to the end if needed.
+ dim = std::min<int32>(dim, existing_dims_size);
+ new_shape.emplace(new_shape.begin() + dim, 1);
+
+ ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ }
+};
+REGISTER_XLA_OP("ExpandDims", ExpandDimsOp);
+
+class SqueezeOp : public XlaOpKernel {
+ public:
+ explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ std::vector<int32> squeeze_dims;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
+ squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ auto existing_dims = input_shape.dim_sizes();
+ int existing_dims_size = input_shape.dims();
+ std::vector<int64> new_shape;
+
+ std::unordered_set<int32> wrapped_squeeze_dims;
+ wrapped_squeeze_dims.reserve(squeeze_dims_.size());
+ // Validate squeeze dims against the input.
+ for (int32 dim : squeeze_dims_) {
+ OP_REQUIRES(ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
+ errors::InvalidArgument("Tried to squeeze dim index ", dim,
+ " for tensor with ",
+ input_shape.dims(), " dimensions."));
+ // If dim is < 0, we wrap around (-1 means the last element).
+ if (dim < 0) {
+ dim = existing_dims_size + dim;
+ }
+
+ wrapped_squeeze_dims.insert(dim);
+ }
+
+ for (int i = 0; i < existing_dims_size; ++i) {
+ auto existing_dim = existing_dims[i];
+
+ // If squeeze_set is non-empty, only squeeze those dimensions.
+ if (!wrapped_squeeze_dims.empty()) {
+ if (wrapped_squeeze_dims.count(i) > 0) {
+ OP_REQUIRES(ctx, existing_dim == 1,
+ errors::InvalidArgument("Tried to explicitly squeeze "
+ "dimension ",
+ i, " but dimension was not 1: ",
+ existing_dim));
+ } else {
+ // This dimension is not being squeezed.
+ new_shape.push_back(existing_dim);
+ }
+ } else {
+ // Copy over all non-1-length dimensions.
+ if (existing_dim != 1) {
+ new_shape.push_back(existing_dim);
+ }
+ }
+ }
+
+ ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape));
+ }
+
+ private:
+ std::unordered_set<int32> squeeze_dims_;
+};
+
+REGISTER_XLA_OP("Squeeze", SqueezeOp);
+
+class ZerosLikeOp : public XlaOpKernel {
+ public:
+ explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
+ ctx->SetOutput(0, ctx->builder()->Broadcast(zero, input_shape.dim_sizes()));
+ }
+};
+
+REGISTER_XLA_OP("ZerosLike", ZerosLikeOp);
+
+} // namespace
+} // namespace tensorflow
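
ExpandDims and Squeeze above only manipulate shape metadata and then emit a single Reshape; the interesting part is how negative dimension indices are wrapped. A standalone sketch of that shape arithmetic in plain C++ (illustrative helper names, not the TF/XLA API):

    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <vector>

    // ExpandDims: insert a size-1 dimension, wrapping negative indices the way
    // the kernel above does (valid range is [-(rank + 1), rank]).
    std::vector<long long> ExpandDims(std::vector<long long> shape, int dim) {
      const int rank = static_cast<int>(shape.size());
      assert(dim >= -1 - rank && dim <= rank);
      if (dim < 0) dim += rank + 1;  // e.g. dim == -1 appends a trailing 1.
      shape.insert(shape.begin() + dim, 1);
      return shape;
    }

    // Squeeze: with no explicit dims, drop every size-1 dimension; with explicit
    // (already wrapped) dims, drop exactly those, which must have size 1.
    std::vector<long long> Squeeze(const std::vector<long long>& shape,
                                   const std::vector<int>& dims) {
      std::vector<long long> out;
      for (int i = 0; i < static_cast<int>(shape.size()); ++i) {
        const bool listed = std::find(dims.begin(), dims.end(), i) != dims.end();
        if (dims.empty()) {
          if (shape[i] != 1) out.push_back(shape[i]);  // Drop every size-1 dim.
        } else if (listed) {
          assert(shape[i] == 1);  // Explicitly squeezed dims must have size 1.
        } else {
          out.push_back(shape[i]);
        }
      }
      return out;
    }

    int main() {
      auto e = ExpandDims({2, 3}, -1);     // {2, 3, 1}
      auto s = Squeeze({2, 1, 3, 1}, {});  // {2, 3}
      std::printf("%lld %lld\n", e.back(), s.back());  // 1 3
    }
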
diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
new file mode 100644
index 0000000000..8ec77e04af
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc
@@ -0,0 +1,121 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Slice Op.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mem.h"
+
+namespace tensorflow {
+namespace {
+
+class SliceOp : public XlaOpKernel {
+ public:
+ explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ bool is_identity = true;
+ std::vector<int64> begin;
+ std::vector<int64> size;
+ SharedValidation(ctx, &is_identity, &begin, &size);
+ if (!ctx->status().ok()) return;
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ // slice will be an empty handle if the output has no elements.
+ CHECK_EQ(begin.size(), size.size());
+ std::vector<int64> limits;
+ for (int i = 0; i < begin.size(); ++i) {
+ limits.push_back(begin[i] + size[i]);
+ }
+ ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits));
+ }
+
+ private:
+ void SharedValidation(XlaOpKernelContext* ctx, bool* is_identity,
+ std::vector<int64>* begin, std::vector<int64>* size);
+};
+
+void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity,
+ std::vector<int64>* begin,
+ std::vector<int64>* size) {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape begin_tensor_shape = ctx->InputShape(1);
+ const TensorShape size_tensor_shape = ctx->InputShape(2);
+
+ OP_REQUIRES(
+ ctx,
+ IsLegacyVector(begin_tensor_shape) && IsLegacyVector(size_tensor_shape) &&
+ begin_tensor_shape.num_elements() == input_shape.dims() &&
+ size_tensor_shape.num_elements() == input_shape.dims(),
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_shape.dims(), ", but got shapes ",
+ begin_tensor_shape.DebugString(), " and ",
+ size_tensor_shape.DebugString(), " instead."));
+
+ const int input_dims = input_shape.dims();
+
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, begin));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, size));
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_shape.dim_size(i) == 0) {
+ OP_REQUIRES(ctx, b == 0 && s == 0,
+ errors::InvalidArgument(
+ "Expected begin[", i, "] == 0 (got ", b, ") and size[", i,
+ "] == 0 ", "(got ", s, ") when ", "input_shape.dim_size(",
+ i, ") == 0"));
+ } else {
+ OP_REQUIRES(
+ ctx, 0 <= b && b <= input_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_shape.dim_size(i), "], but got ", b));
+ OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+REGISTER_XLA_OP("Slice", SliceOp);
+
+} // namespace
+} // namespace tensorflow
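
SharedValidation above resolves size[i] == -1 to "everything from begin[i] to the end of dimension i", and the kernel then converts (begin, size) into the half-open (begin, limit) form that the Slice builder call expects. A standalone sketch of that resolution in plain C++ (illustrative only):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Resolve size[i] == -1 and convert (begin, size) into half-open (begin, limit).
    // Plain C++ sketch, not the TF/XLA API.
    void ResolveSlice(const std::vector<long long>& dims,
                      std::vector<long long>& begin, std::vector<long long>& size,
                      std::vector<long long>* limits, bool* is_identity) {
      *is_identity = true;
      limits->clear();
      for (size_t i = 0; i < dims.size(); ++i) {
        if (size[i] == -1) size[i] = dims[i] - begin[i];  // "To the end."
        assert(begin[i] >= 0 && begin[i] + size[i] <= dims[i]);
        limits->push_back(begin[i] + size[i]);
        *is_identity &= (begin[i] == 0 && size[i] == dims[i]);
      }
    }

    int main() {
      std::vector<long long> begin = {1, 0}, size = {-1, 2}, limits;
      bool identity = false;
      ResolveSlice({4, 5}, begin, size, &limits, &identity);
      std::printf("size=[%lld,%lld] limits=[%lld,%lld] identity=%d\n",
                  size[0], size[1], limits[0], limits[1], identity);
      // size=[3,2] limits=[4,2] identity=0
    }
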
diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
new file mode 100644
index 0000000000..06ee520163
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc
@@ -0,0 +1,152 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for softmax.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class SoftmaxOp : public XlaOpKernel {
+ public:
+ explicit SoftmaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ log_ = StringPiece(type_string()).starts_with("Log");
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape logits_shape = ctx->InputShape(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const DataType type = input_type(0);
+ auto logits = ctx->Input(0);
+
+ xla::ComputationBuilder* b = ctx->builder();
+ const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+ // Find the max in each batch, resulting in a tensor of shape [batch]
+ auto logits_max =
+ b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+ // Subtract the max in batch b from every element in batch b. Broadcasts
+ // along the batch dimension.
+ auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+ xla::ComputationDataHandle softmax;
+ if (log_) {
+ // softmax = shifted_logits - log(sum(exp(shifted_logits)))
+ auto log_sum_exp =
+ b->Log(b->Reduce(b->Exp(shifted_logits), XlaHelpers::Zero(b, type),
+ add_func, {kClassDim}));
+ softmax = b->Sub(shifted_logits, log_sum_exp, {kBatchDim});
+ } else {
+ // softmax = exp(shifted_logits) / sum(exp(shifted_logits))
+ auto exp_shifted = b->Exp(shifted_logits);
+ auto sum_exp = b->Reduce(exp_shifted, XlaHelpers::Zero(b, type), add_func,
+ {kClassDim});
+ softmax = b->Div(exp_shifted, sum_exp, {kBatchDim});
+ }
+
+ ctx->SetOutput(0, softmax);
+ }
+
+ private:
+ bool log_;
+};
+
+REGISTER_XLA_OP("Softmax", SoftmaxOp);
+REGISTER_XLA_OP("LogSoftmax", SoftmaxOp);
+
+class SoftmaxXentWithLogitsOp : public XlaOpKernel {
+ public:
+ explicit SoftmaxXentWithLogitsOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape logits_shape = ctx->InputShape(0);
+ const TensorShape labels_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, logits_shape.IsSameSize(labels_shape),
+ errors::InvalidArgument(
+ "logits and labels must be same size: logits_size=",
+ logits_shape.DebugString(), " labels_size=",
+ labels_shape.DebugString()));
+ OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape),
+ errors::InvalidArgument("logits must be 2-dimensional"));
+ // As we already tested that both inputs have the same shape no need to
+ // check that "labels" is a matrix too.
+
+ // loss is 1-D (one per example), and size is batch_size.
+
+ const int kBatchDim = 0;
+ const int kClassDim = 1;
+
+ const DataType type = input_type(0);
+ xla::ComputationBuilder* b = ctx->builder();
+ auto logits = ctx->Input(0);
+ auto labels = ctx->Input(1);
+
+ const xla::Computation& max_func = *ctx->GetOrCreateMax(type);
+ const xla::Computation& add_func = *ctx->GetOrCreateAdd(type);
+
+ // Find the max in each batch, resulting in a tensor of shape [batch]
+ auto logits_max =
+ b->Reduce(logits, XlaHelpers::MinValue(b, type), max_func, {kClassDim});
+
+ // Subtract the max in batch b from every element in batch b.
+ // Broadcasts along the batch dimension.
+ auto shifted_logits = b->Sub(logits, logits_max, {kBatchDim});
+
+ // exp(logits - max_logits)
+ auto exp_shifted_logits = b->Exp(shifted_logits);
+
+ // sum_{class} (exp(logits - max_logits))
+ auto sum_exp = b->Reduce(exp_shifted_logits, XlaHelpers::Zero(b, type),
+ add_func, {kClassDim});
+
+ // log(sum(exp(logits - max_logits)))
+ auto log_sum_exp = b->Log(sum_exp);
+
+ // sum(-labels *
+ // ((logits - max_logits) - log(sum(exp(logits - max_logits)))))
+ // along classes
+ // (The subtraction broadcasts along the batch dimension.)
+ xla::ComputationDataHandle loss =
+ b->Reduce(b->Mul(b->Neg(labels),
+ b->Sub(shifted_logits, log_sum_exp, {kBatchDim})),
+ XlaHelpers::Zero(b, type), add_func, {kClassDim});
+
+ // backprop: prob - labels, where
+ // prob = exp(logits - max_logits) / sum(exp(logits - max_logits))
+ // (where the division broadcasts along the batch dimension)
+ xla::ComputationDataHandle backprop =
+ b->Sub(b->Div(exp_shifted_logits, sum_exp, {kBatchDim}), labels);
+
+ ctx->SetOutput(0, loss);
+ ctx->SetOutput(1, backprop);
+ }
+};
+
+REGISTER_XLA_OP("SoftmaxCrossEntropyWithLogits", SoftmaxXentWithLogitsOp);
+
+} // namespace
+} // namespace tensorflow
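
Both kernels above subtract the per-row maximum before exponentiating, so exp never overflows, and the log form reuses the same shifted logits for the log-sum-exp term. A scalar, single-row sketch of that numerics in plain C++ (illustrative only, not the builder API):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax / log-softmax of one row, mirroring the
    // max-shift + log-sum-exp structure of the XLA graph built above.
    std::vector<double> SoftmaxRow(const std::vector<double>& logits, bool log_form) {
      double max_logit = logits[0];
      for (double v : logits) max_logit = std::max(max_logit, v);

      double sum_exp = 0.0;
      for (double v : logits) sum_exp += std::exp(v - max_logit);
      const double log_sum_exp = std::log(sum_exp);

      std::vector<double> out(logits.size());
      for (size_t i = 0; i < logits.size(); ++i) {
        const double shifted = logits[i] - max_logit;
        out[i] = log_form ? shifted - log_sum_exp : std::exp(shifted) / sum_exp;
      }
      return out;
    }

    int main() {
      for (double p : SoftmaxRow({1.0, 2.0, 3.0}, /*log_form=*/false))
        std::printf("%.4f ", p);  // ~0.0900 0.2447 0.6652
      std::printf("\n");
    }
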
diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc
new file mode 100644
index 0000000000..18c4c648db
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Ops for split.
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace {
+
+class SplitOp : public XlaOpKernel {
+ public:
+ explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape index_shape = ctx->InputShape(0);
+ xla::Literal literal_index;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index));
+
+ int32 split_dim;
+ if (index_shape.dims() == 0) {
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
+ } else {
+ OP_REQUIRES(
+ ctx, index_shape.dims() == 1,
+ errors::InvalidArgument("split_index input to Split Op must be a "
+ "scalar or a vector with 1 element"));
+ OP_REQUIRES(
+ ctx, index_shape.dim_size(0) == 1,
+ errors::InvalidArgument("split_index input to Split Op must be a "
+ "scalar or a vector with 1 element"));
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {0});
+ }
+ const int32 num_split = num_outputs();
+ const TensorShape input_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(
+ ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
+ input_shape.dims(), "), but got ", split_dim));
+
+ OP_REQUIRES(
+ ctx, num_split > 0,
+ errors::InvalidArgument(
+ "Number of ways to split should be > 0, but got ", num_split));
+
+ OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0,
+ errors::InvalidArgument(
+ "Number of ways to split should evenly divide the split "
+ "dimension, but got split_dim ",
+ split_dim, " (size = ", input_shape.dim_size(split_dim),
+ ") ", "and num_split ", num_split));
+
+ // All the slices are the same size: this is the size along the
+ // split dimension.
+ const int32 slice_size = input_shape.dim_size(split_dim) / num_split;
+
+    // The vectors we will use to define the slice. The entry for the
+    // split dimension varies for each output.
+ std::vector<int64> begin;
+ std::vector<int64> limits;
+ for (int i = 0; i < input_shape.dims(); ++i) {
+ // Initially set up the limits to be the full size of the input:
+ // the split dimension is filled in below.
+ int64 dim = input_shape.dim_size(i);
+ begin.push_back(0);
+ limits.push_back(dim);
+ }
+
+ auto input = ctx->Input(1);
+
+ // Create each of the outputs.
+ for (int i = 0; i < num_split; ++i) {
+ // Slice out the ith split from the split dimension.
+ begin[split_dim] = i * slice_size;
+ limits[split_dim] = (i + 1) * slice_size;
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ }
+ }
+};
+
+REGISTER_XLA_OP("Split", SplitOp);
+
+class SplitVOp : public XlaOpKernel {
+ public:
+ explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const int32 num_split = num_outputs();
+ const TensorShape index_shape = ctx->InputShape(2);
+ xla::Literal literal_index;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &literal_index));
+
+ int32 split_dim;
+ OP_REQUIRES(ctx, index_shape.dims() == 0,
+ errors::InvalidArgument("split_dim input to Split Op must be a "
+ "scalar"));
+ split_dim = xla::LiteralUtil::Get<int>(literal_index, {});
+
+ xla::ComputationDataHandle input = ctx->Input(0);
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ OP_REQUIRES(ctx, input_shape.dims() > 0,
+ errors::InvalidArgument("Can't split a 0 dimensional input"));
+
+ OP_REQUIRES(
+ ctx, 0 <= split_dim && split_dim < input_shape.dims(),
+ errors::InvalidArgument("0 <= split_dim < number of input dimensions (",
+ input_shape.dims(), "), but got ", split_dim));
+
+ OP_REQUIRES(
+ ctx, num_split > 0,
+ errors::InvalidArgument(
+ "Number of ways to split should be > 0, but got ", num_split));
+
+ // check that sizes are correct
+ int total_split_size = 0;
+ int neg_one_dim = -1;
+ std::vector<int64> split_sizes_vec(num_split, -1);
+ const TensorShape split_size_shape = ctx->InputShape(1);
+    OP_REQUIRES(ctx, split_size_shape.dims() == 1 &&
+                         split_size_shape.num_elements() == num_split,
+                errors::InvalidArgument(
+                    "shape of tensor describing "
+                    "the output must have dimension 1 and the same "
+                    "number of elements as the number of outputs. Got ",
+                    split_size_shape.dims(), "-D and ",
+                    split_size_shape.num_elements(), " elements"));
+ // get the dimension of this split
+ xla::Literal split_size_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &split_size_literal));
+
+ for (int i = 0; i < num_split; ++i) {
+ int slice_size;
+ slice_size = xla::LiteralUtil::Get<int>(split_size_literal, {i});
+      if (slice_size == -1) {
+        OP_REQUIRES(
+            ctx, neg_one_dim == -1,
+            errors::InvalidArgument("Only one dimension can have a value of "
+                                    "-1. Second one found at dimension ",
+                                    i));
+ neg_one_dim = i;
+ } else {
+ split_sizes_vec[i] = slice_size;
+ total_split_size += slice_size;
+ }
+ }
+
+ OP_REQUIRES(
+ ctx, (neg_one_dim == -1 &&
+ total_split_size == input_shape.dim_size(split_dim)) ||
+ (neg_one_dim >= 0 &&
+ total_split_size <= input_shape.dim_size(split_dim)),
+ errors::InvalidArgument("Determined shape must either match "
+ "input shape along split_dim exactly if "
+ "fully specified, or be less than the size of "
+ "the input along split_dim if not fully "
+ "specified. Got: ",
+ total_split_size));
+
+ if (neg_one_dim >= 0) {
+ split_sizes_vec[neg_one_dim] =
+ input_shape.dim_size(split_dim) - total_split_size;
+ }
+
+    // The vectors we will use to define the slice. The entry for the
+    // split dimension varies for each output.
+ std::vector<int64> begin(input_shape.dims(), 0);
+ auto dim_sizes = input_shape.dim_sizes();
+ std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
+
+ for (int i = 0; i < num_split; ++i) {
+ TensorShape output_shape(input_shape);
+ int slice_size = split_sizes_vec[i];
+ output_shape.set_dim(split_dim, slice_size);
+
+ // Slice out the ith split from the split dimension.
+ limits[split_dim] = begin[split_dim] + slice_size;
+ ctx->SetOutput(i, ctx->builder()->Slice(input, begin, limits));
+ begin[split_dim] = limits[split_dim];
+ }
+ }
+};
+
+REGISTER_XLA_OP("SplitV", SplitVOp);
+
+} // namespace
+} // namespace tensorflow
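
SplitVOp above allows at most one requested size to be -1 and infers it as whatever remains of the split dimension. A standalone sketch of that size resolution in plain C++ (illustrative only):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Resolve SplitV-style sizes: at most one entry may be -1, and it absorbs
    // the remainder of dim_size. Plain C++ sketch, not the TF API.
    std::vector<long long> ResolveSplitSizes(std::vector<long long> sizes,
                                             long long dim_size) {
      long long total = 0;
      int neg_one_index = -1;
      for (int i = 0; i < static_cast<int>(sizes.size()); ++i) {
        if (sizes[i] == -1) {
          assert(neg_one_index == -1);  // Only one -1 entry is allowed.
          neg_one_index = i;
        } else {
          total += sizes[i];
        }
      }
      if (neg_one_index >= 0) {
        assert(total <= dim_size);
        sizes[neg_one_index] = dim_size - total;  // Inferred remainder.
      } else {
        assert(total == dim_size);  // Fully specified sizes must match exactly.
      }
      return sizes;
    }

    int main() {
      auto s = ResolveSplitSizes({2, -1, 3}, 10);
      std::printf("%lld %lld %lld\n", s[0], s[1], s[2]);  // 2 5 3
    }
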
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
new file mode 100644
index 0000000000..83bf24814f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/strided_slice_op.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mem.h"
+
+namespace tensorflow {
+namespace {
+
+class StridedSliceOp : public XlaOpKernel {
+ public:
+ explicit StridedSliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ TensorShape final_shape;
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> end;
+ gtl::InlinedVector<int64, 4> strides;
+
+ xla::Literal begin_literal, end_literal, strides_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
+
+ Tensor begin_tensor, end_tensor, strides_tensor;
+ OP_REQUIRES_OK(
+ ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
+ OP_REQUIRES_OK(ctx,
+ LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+ OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
+ &strides_tensor));
+
+ TensorShape dummy_processing_shape;
+ ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
+ ShapeReadWriteFromTensorShape wrapped_dummy_processing_shape(
+ &dummy_processing_shape);
+ bool dummy = false;
+ OP_REQUIRES_OK(
+ ctx, ValidateStridedSliceOp(
+ &begin_tensor, &end_tensor, strides_tensor,
+ ShapeReadWriteFromTensorShape(&input_shape), begin_mask_,
+ end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+ &wrapped_dummy_processing_shape, &wrapped_final_shape, &dummy,
+ &dummy, &dummy, &begin, &end, &strides));
+
+ gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+ gtl::InlinedVector<int64, 4> slice_begin, slice_end;
+ for (int i = 0; i < begin.size(); ++i) {
+ // TODO(phawkins): implement strides != 1 when b/30878775 is fixed.
+ OP_REQUIRES(
+ ctx, strides[i] == 1 || strides[i] == -1,
+ errors::Unimplemented("Strides != 1 or -1 are not yet implemented"));
+ if (strides[i] > 0) {
+ slice_begin.push_back(begin[i]);
+ slice_end.push_back(end[i]);
+ } else {
+ // Negative stride: swap begin and end, add 1 because the interval
+ // is semi-open, and mark the dimension to be reversed.
+ slice_begin.push_back(end[i] + 1);
+ slice_end.push_back(begin[i] + 1);
+ dimensions_to_reverse.push_back(i);
+ }
+ }
+ xla::ComputationDataHandle slice =
+ ctx->builder()->Slice(ctx->Input(0), slice_begin, slice_end);
+ if (!dimensions_to_reverse.empty()) {
+ slice = ctx->builder()->Rev(slice, dimensions_to_reverse);
+ }
+
+ slice = ctx->builder()->Reshape(slice, final_shape.dim_sizes());
+ ctx->SetOutput(0, slice);
+ }
+
+ private:
+ int32 begin_mask_, end_mask_;
+ int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
+ DataType index_type_;
+};
+
+REGISTER_XLA_OP("StridedSlice", StridedSliceOp);
+
+class StridedSliceGradOp : public XlaOpKernel {
+ public:
+ explicit StridedSliceGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("begin_mask", &begin_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("end_mask", &end_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("ellipsis_mask", &ellipsis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape processing_shape, final_shape;
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> end;
+ gtl::InlinedVector<int64, 4> strides;
+
+ TensorShape input_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
+
+ xla::Literal begin_literal, end_literal, strides_literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &begin_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &end_literal));
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(3, &strides_literal));
+
+ Tensor begin_tensor, end_tensor, strides_tensor;
+ OP_REQUIRES_OK(
+ ctx, LiteralToHostTensor(begin_literal, index_type_, &begin_tensor));
+ OP_REQUIRES_OK(ctx,
+ LiteralToHostTensor(end_literal, index_type_, &end_tensor));
+ OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
+ &strides_tensor));
+
+ bool dummy = false;
+ ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
+ ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape);
+ OP_REQUIRES_OK(
+ ctx, ValidateStridedSliceOp(
+ &begin_tensor, &end_tensor, strides_tensor,
+ ShapeReadWriteFromTensorShape(&input_shape), begin_mask_,
+ end_mask_, ellipsis_mask_, new_axis_mask_, shrink_axis_mask_,
+ &wrapped_processing_shape, &wrapped_final_shape, &dummy,
+ &dummy, &dummy, &begin, &end, &strides));
+
+ // Check to make sure dy is consistent with the original slice
+ const TensorShape dy_shape = ctx->InputShape(4);
+ OP_REQUIRES(
+ ctx, final_shape == dy_shape,
+ errors::InvalidArgument("shape of dy was ", dy_shape.DebugString(),
+ " instead of ", final_shape.DebugString()));
+
+ OP_REQUIRES(
+ ctx, input_shape.dims() == processing_shape.dims(),
+ errors::Internal(
+ "input shape and processing shape must have same number of dims"));
+
+ auto zero = XlaHelpers::Zero(ctx->builder(), ctx->expected_output_dtype(0));
+
+ xla::ComputationDataHandle grad = ctx->Input(4);
+
+ // Undo any new/shrink axes.
+ grad = ctx->builder()->Reshape(grad, processing_shape.dim_sizes());
+
+ // Pad the input gradients.
+ gtl::InlinedVector<int64, 4> dimensions_to_reverse;
+ xla::PaddingConfig padding_config;
+
+ for (int i = 0; i < processing_shape.dims(); ++i) {
+ auto* dims = padding_config.add_dimensions();
+ if (strides[i] > 0) {
+ dims->set_edge_padding_low(begin[i]);
+ dims->set_interior_padding(strides[i] - 1);
+
+ // Pad the upper dimension up to the expected input shape. (It's
+ // not sufficient simply to use "end[i]" to compute the padding in
+ // cases where the stride does not divide evenly into the interval
+ // between begin[i] and end[i].)
+ int64 size =
+ dims->edge_padding_low() + processing_shape.dim_size(i) +
+ (processing_shape.dim_size(i) - 1) * dims->interior_padding();
+ dims->set_edge_padding_high(input_shape.dim_size(i) - size);
+ } else {
+ dimensions_to_reverse.push_back(i);
+ dims->set_edge_padding_high(input_shape.dim_size(i) - begin[i] - 1);
+ dims->set_interior_padding(-strides[i] - 1);
+
+ // Pad the lower dimension up to the expected input shape.
+ int64 size =
+ dims->edge_padding_high() + processing_shape.dim_size(i) +
+ (processing_shape.dim_size(i) - 1) * dims->interior_padding();
+ dims->set_edge_padding_low(input_shape.dim_size(i) - size);
+ }
+ }
+ if (!dimensions_to_reverse.empty()) {
+ grad = ctx->builder()->Rev(grad, dimensions_to_reverse);
+ }
+ grad = ctx->builder()->Pad(grad, zero, padding_config);
+ ctx->SetOutput(0, grad);
+ }
+
+ private:
+ int32 begin_mask_, end_mask_;
+ int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
+ DataType index_type_;
+};
+
+REGISTER_XLA_OP("StridedSliceGrad", StridedSliceGradOp);
+
+} // namespace
+} // namespace tensorflow
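
StridedSliceOp above only handles strides of +1 or -1: a negative stride is lowered to a forward slice over the equivalent half-open interval plus a reversal of that dimension. A standalone sketch of that interval rewrite in plain C++ (illustrative only):

    #include <cstdio>
    #include <vector>

    // For stride +1 keep [begin, end); for stride -1, slice the equivalent
    // forward half-open interval [end + 1, begin + 1) and remember to reverse
    // that dimension afterwards, mirroring the loop in StridedSliceOp::Compile.
    struct SliceBounds {
      std::vector<long long> begin, end;
      std::vector<int> reverse_dims;
    };

    SliceBounds LowerUnitStrides(const std::vector<long long>& begin,
                                 const std::vector<long long>& end,
                                 const std::vector<long long>& strides) {
      SliceBounds out;
      for (size_t i = 0; i < begin.size(); ++i) {
        if (strides[i] > 0) {
          out.begin.push_back(begin[i]);
          out.end.push_back(end[i]);
        } else {
          out.begin.push_back(end[i] + 1);
          out.end.push_back(begin[i] + 1);
          out.reverse_dims.push_back(static_cast<int>(i));
        }
      }
      return out;
    }

    int main() {
      // x[4:0:-1] over a length-5 axis: begin=4, end=0, stride=-1.
      auto b = LowerUnitStrides({4}, {0}, {-1});
      std::printf("slice [%lld,%lld) then reverse dim %d\n",
                  b.begin[0], b.end[0], b.reverse_dims[0]);  // [1,5) then reverse 0
    }
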
diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
new file mode 100644
index 0000000000..45ac5e12c7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Tile Op.
+
+#include <vector>
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace {
+
+// --------------------------------------------------------------------------
+class TileOp : public XlaOpKernel {
+ public:
+ explicit TileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape multiples_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(
+ ctx, IsLegacyVector(multiples_shape),
+ errors::InvalidArgument("Expected multiples to be 1-D, but got shape ",
+ multiples_shape.DebugString()));
+ OP_REQUIRES(ctx, input_shape.dims() == multiples_shape.num_elements(),
+ errors::InvalidArgument(
+ "Expected multiples argument to be a vector of length ",
+ input_shape.dims(), " but got length ",
+ multiples_shape.dim_size(0)));
+ const int input_dims = input_shape.dims();
+
+ // If input is a scalar then multiples has 0 elements and this is
+ // a NoOp.
+ if (input_dims == 0) {
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal));
+
+    // Walk the multiples: build the output shape, and record whether all
+    // multiples are one and whether binary-operation broadcast semantics
+    // alone can implement the tile.
+ std::vector<int64> multiples_array(input_dims);
+ std::vector<int64> output_shape;
+ bool all_multiples_are_one = true;
+ bool one_dimension_is_broadcasted_without_multiple = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int multiple = xla::LiteralUtil::Get<int>(literal, {i});
+ OP_REQUIRES(ctx, multiple,
+ errors::InvalidArgument("Expected multiples[", i,
+ "] >= 0, but got ", multiple));
+ int64 new_dim = input_shape.dim_size(i) * multiple;
+ output_shape.push_back(new_dim);
+ multiples_array[i] = multiple;
+ all_multiples_are_one = all_multiples_are_one && multiple == 1;
+ // If the multiple of a non-one dimensions is not one, then binary
+ // operation broadcast semantics will not be sufficient to implement the
+ // tile operation.
+ one_dimension_is_broadcasted_without_multiple =
+ one_dimension_is_broadcasted_without_multiple &&
+ ((input_shape.dim_size(i) > 1 && multiple == 1) ||
+ input_shape.dim_size(i) == 1);
+ }
+ auto input = ctx->Input(0);
+    // If all multiples are 1, then the input is the same as the output.
+ if (all_multiples_are_one) {
+ ctx->SetOutput(0, input);
+ return;
+ }
+ if (one_dimension_is_broadcasted_without_multiple) {
+ // Create a constant Zero the size of the output shape to leverage binary
+ // operation broadcast semantics.
+ auto broadcasted_zero = ctx->builder()->Broadcast(
+ XlaHelpers::Zero(ctx->builder(), ctx->input_type(0)), output_shape);
+ ctx->SetOutput(0, ctx->builder()->Add(broadcasted_zero, input));
+ return;
+ }
+
+ // First broadcast the requisite number of multiples along each
+ // dimension. This prepends the broadcasted dimensions, so an
+ // input of shape [2,3,1] broadcast with multiples [5,4,3] will
+ // end up with shape [5,4,3,2,3,1].
+ auto broadcasted = ctx->builder()->Broadcast(input, multiples_array);
+ // Now flatten and reshape. The broadcasted dimensions are
+ // paired with the original dimensions so in the above example
+ // we flatten [0,3,1,4,2,5] then reshape to [10,12,3].
+ std::vector<int64> flattened;
+ for (int i = 0; i < output_shape.size(); ++i) {
+ flattened.push_back(i);
+ flattened.push_back(i + output_shape.size());
+ }
+ xla::ComputationDataHandle output =
+ ctx->builder()->Reshape(broadcasted, flattened, output_shape);
+
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TileOp);
+};
+
+REGISTER_XLA_OP("Tile", TileOp);
+
+} // namespace
+} // namespace tensorflow
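
The general path in TileOp broadcasts the multiples in front of the input dimensions and then reshapes with an interleaved dimension order, as described in the comments above. A standalone sketch of just that index bookkeeping in plain C++ (illustrative only):

    #include <cstdio>
    #include <vector>

    // Given input dims and per-dimension multiples, compute the interleaved
    // dimension order and the final output shape used by the Reshape call above.
    // E.g. input [2,3,1] with multiples [5,4,3]: broadcast gives [5,4,3,2,3,1],
    // the order is [0,3,1,4,2,5], and the output shape is [10,12,3].
    void TilePlan(const std::vector<long long>& input_dims,
                  const std::vector<long long>& multiples,
                  std::vector<long long>* dim_order,
                  std::vector<long long>* output_shape) {
      const long long n = static_cast<long long>(input_dims.size());
      for (long long i = 0; i < n; ++i) {
        dim_order->push_back(i);      // The broadcasted (multiple) dimension.
        dim_order->push_back(i + n);  // The matching original dimension.
        output_shape->push_back(input_dims[i] * multiples[i]);
      }
    }

    int main() {
      std::vector<long long> order, shape;
      TilePlan({2, 3, 1}, {5, 4, 3}, &order, &shape);
      for (long long d : order) std::printf("%lld ", d);  // 0 3 1 4 2 5
      std::printf("| ");
      for (long long s : shape) std::printf("%lld ", s);  // 10 12 3
      std::printf("\n");
    }
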
diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
new file mode 100644
index 0000000000..2840abc878
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc
@@ -0,0 +1,134 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA-specific Transpose Op. This is very different from the Eigen
+// version in third_party/tensorflow because XLA's reshape neatly
+// handles all transposes, while Eigen needs a restricted DoTranspose
+// helper.
+
+#include "tensorflow/core/kernels/transpose_op.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+
+namespace tensorflow {
+namespace {
+
+class TransposeOp : public XlaOpKernel {
+ public:
+ explicit TransposeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape perm_tensor_shape = ctx->InputShape(1);
+
+ // Preliminary validation of sizes.
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(perm_tensor_shape),
+ errors::InvalidArgument("perm must be a vector, not ",
+ perm_tensor_shape.DebugString()));
+
+ const int dims = input_shape.dims();
+ OP_REQUIRES(ctx, dims == perm_tensor_shape.num_elements(),
+ errors::InvalidArgument("transpose expects a vector of size ",
+ input_shape.dims(),
+ ". But input(1) is a vector of size ",
+ perm_tensor_shape.num_elements()));
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal));
+
+ std::vector<int32> perm(dims);
+ std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin());
+
+ std::vector<int64> transposed_order;
+ // Check whether permutation is a permutation of integers of [0 .. dims).
+ gtl::InlinedVector<bool, 8> bits(dims);
+ bool is_identity = true;
+ for (int i = 0; i < dims; ++i) {
+ const int32 d = perm[i];
+ OP_REQUIRES(
+ ctx, 0 <= d && d < dims,
+ errors::InvalidArgument(d, " is out of range [0 .. ", dims, ")"));
+ bits[d] = true;
+ transposed_order.push_back(d);
+ if (d != i) {
+ is_identity = false;
+ }
+ }
+ for (int i = 0; i < dims; ++i) {
+ OP_REQUIRES(ctx, bits[i], errors::InvalidArgument(
+ i, " is missing from 'perm' argument."));
+ }
+
+ // 0-D, 1-D, and identity transposes do nothing.
+ if (dims <= 1 || is_identity) {
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ ctx->SetOutput(0,
+ ctx->builder()->Transpose(ctx->Input(0), transposed_order));
+ }
+};
+
+REGISTER_XLA_OP("Transpose", TransposeOp);
+
+// InvertPermutation frequently forms part of the gradient of Transpose.
+//
+// inv = InvertPermutationOp(T<int32> p) takes a permutation of
+// integers 0, 1, ..., n - 1 and returns the inverted
+// permutation of p. I.e., inv[p[i]] == i, for i in [0 .. n).
+//
+// REQUIRES: input is a vector of int32.
+// REQUIRES: input is a permutation of 0, 1, ..., n-1.
+
+class InvertPermutationOp : public XlaOpKernel {
+ public:
+ explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(),
+ std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("permutation of nonnegative int32s "
+ "must have <= int32 max elements"));
+
+ std::vector<int64> perm;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &perm));
+
+ int size = perm.size();
+
+ std::vector<int32> output(size);
+ std::fill_n(output.data(), size, -1);
+ for (int i = 0; i < size; ++i) {
+ const int64 d = perm[i];
+ OP_REQUIRES(ctx, FastBoundsCheck(d, size),
+ errors::InvalidArgument(d, " is not between 0 and ", size));
+ OP_REQUIRES(ctx, output[d] == -1,
+ errors::InvalidArgument(d, " is duplicated in the input."));
+ output[d] = i;
+ }
+
+ ctx->SetOutput(0, ctx->builder()->ConstantR1<int32>(output));
+ }
+};
+
+REGISTER_XLA_OP("InvertPermutation", InvertPermutationOp);
+
+} // namespace
+} // namespace tensorflow
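
InvertPermutationOp above materializes the inverse permutation as a constant, using -1 sentinels to detect duplicates. The same logic in standalone form, in plain C++ (illustrative only):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // inv[p[i]] = i; every output slot must be written exactly once, so a
    // leftover -1 (or a second write) indicates an invalid permutation.
    std::vector<int> InvertPermutation(const std::vector<int>& perm) {
      std::vector<int> inv(perm.size(), -1);
      for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
        const int d = perm[i];
        assert(d >= 0 && d < static_cast<int>(perm.size()));  // In range.
        assert(inv[d] == -1);                                 // Not duplicated.
        inv[d] = i;
      }
      return inv;
    }

    int main() {
      auto inv = InvertPermutation({2, 0, 1});
      std::printf("%d %d %d\n", inv[0], inv[1], inv[2]);  // 1 2 0
    }
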
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
new file mode 100644
index 0000000000..eced089b32
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Native XLA implementations of simple unary Ops
+
+#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+
+namespace tensorflow {
+namespace {
+
+// Each operator defined via XLAJIT_MAKE_UNARY builds the computation that
+// describes the scalar->scalar function to apply to each element of
+// the input.
+#define XLAJIT_MAKE_UNARY(Name, COMPUTATION) \
+ class Name##Op : public XlaOpKernel { \
+ public: \
+ explicit Name##Op(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} \
+    void Compile(XlaOpKernelContext* ctx) override { \
+ xla::ComputationBuilder& b = *ctx->builder(); \
+ xla::ComputationDataHandle x = ctx->Input(0); \
+ xla::ComputationDataHandle y = COMPUTATION; \
+ ctx->SetOutput(0, y); \
+ } \
+ }; \
+ REGISTER_XLA_OP(#Name, Name##Op);
+
+// Return x if x>0, otherwise -x.
+XLAJIT_MAKE_UNARY(Abs, b.Abs(x));
+XLAJIT_MAKE_UNARY(Ceil, b.Ceil(x));
+XLAJIT_MAKE_UNARY(Exp, b.Exp(x));
+XLAJIT_MAKE_UNARY(Floor, b.Floor(x));
+// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
+XLAJIT_MAKE_UNARY(Sign, b.Sign(x));
+// Return 1/x
+XLAJIT_MAKE_UNARY(Inv, b.Div(XlaHelpers::One(&b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Reciprocal, b.Div(XlaHelpers::One(&b, input_type(0)), x));
+XLAJIT_MAKE_UNARY(Log, b.Log(x));
+XLAJIT_MAKE_UNARY(LogicalNot, b.LogicalNot(x));
+XLAJIT_MAKE_UNARY(Neg, b.Neg(x));
+XLAJIT_MAKE_UNARY(Rsqrt,
+ b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), -0.5)));
+XLAJIT_MAKE_UNARY(Sigmoid, b.Map({x}, *ctx->GetOrCreateSigmoid(input_type(0))));
+XLAJIT_MAKE_UNARY(Softplus,
+ b.Log(b.Add(b.Exp(x), XlaHelpers::One(&b, input_type(0)))));
+XLAJIT_MAKE_UNARY(Sqrt,
+ b.Pow(x, XlaHelpers::FloatLiteral(&b, input_type(0), 0.5)));
+XLAJIT_MAKE_UNARY(Square, b.Mul(x, x));
+XLAJIT_MAKE_UNARY(Tanh, b.Tanh(x));
+
+#undef XLAJIT_MAKE_UNARY
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unpack_op.cc b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
new file mode 100644
index 0000000000..c5b2bdaf2d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/unpack_op.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA Unpack operator.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+class UnpackOp : public XlaOpKernel {
+ public:
+ explicit UnpackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const int num = num_outputs();
+ const TensorShape input_shape = ctx->InputShape(0);
+
+ int axis = axis_;
+ if (axis < 0) axis += input_shape.dims();
+
+ OP_REQUIRES(ctx, 0 <= axis && axis < input_shape.dims(),
+ errors::InvalidArgument("axis = ", axis_, " not in [",
+ -input_shape.dims(), ", ",
+ input_shape.dims(), ")"));
+
+ OP_REQUIRES(
+ ctx, input_shape.dims() > 0 && input_shape.dim_size(axis) == num,
+ errors::InvalidArgument("Input shape axis ", axis, " must equal ", num,
+ ", got shape ", input_shape.DebugString()));
+
+ auto output_shape = input_shape;
+ output_shape.RemoveDim(axis);
+
+ auto input = ctx->Input(0);
+
+ std::vector<int64> start_indices(input_shape.dims(), 0);
+ std::vector<int64> limit_indices(input_shape.dims());
+ for (int i = 0; i < input_shape.dims(); ++i) {
+ limit_indices[i] = input_shape.dim_size(i);
+ }
+
+ for (int i = 0; i < num; ++i) {
+ start_indices[axis] = i;
+ limit_indices[axis] = i + 1;
+ auto slice = ctx->builder()->Slice(input, start_indices, limit_indices);
+ // Reshape to drop the 'axis' dimension.
+ auto result = ctx->builder()->Reshape(slice, output_shape.dim_sizes());
+ ctx->SetOutput(i, result);
+ }
+ }
+
+ private:
+ int axis_;
+};
+
+REGISTER_XLA_OP("Unpack", UnpackOp);
+
+} // namespace
+} // namespace tensorflow
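
UnpackOp above emits one unit-width slice per output along the chosen axis and then reshapes away that axis. A standalone sketch of the per-output slice bounds, in plain C++ (illustrative only):

    #include <cstdio>
    #include <vector>

    // For output i, slice [i, i+1) along `axis` and keep the full extent of
    // every other dimension; the kernel above then drops the unit dimension.
    void UnpackBounds(const std::vector<long long>& dims, int axis, int i,
                      std::vector<long long>* start, std::vector<long long>* limit) {
      *start = std::vector<long long>(dims.size(), 0);
      *limit = dims;
      (*start)[axis] = i;
      (*limit)[axis] = i + 1;
    }

    int main() {
      std::vector<long long> start, limit;
      UnpackBounds({4, 3}, /*axis=*/0, /*i=*/2, &start, &limit);
      std::printf("start=[%lld,%lld] limit=[%lld,%lld]\n",
                  start[0], start[1], limit[0], limit[1]);  // [2,0] [3,3]
    }
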
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
new file mode 100644
index 0000000000..1f2bc01cf4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+
+namespace tensorflow {
+
+Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
+ literal->Clear();
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
+ host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape()));
+
+ xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal);
+
+ // memcpy over the payload ...
+ // TODO(phawkins): handle string types.
+ size_t total_bytes = host_tensor.TotalBytes();
+ if (total_bytes > 0) {
+ void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal);
+ const void* src_ptr = DMAHelper::base(&host_tensor);
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ }
+ return Status::OK();
+}
+
+Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
+ Tensor* host_tensor) {
+ xla::PrimitiveType primitive_type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type));
+ if (literal.shape().element_type() != primitive_type) {
+ return errors::InvalidArgument(
+ "Cannot convert literal of type ",
+ xla::PrimitiveType_Name(literal.shape().element_type()),
+ " to tensor of type ", DataTypeString(target_type));
+ }
+
+ TensorShape shape = XLAShapeToTensorShape(literal.shape());
+ *host_tensor = Tensor(target_type, shape);
+ size_t total_bytes = host_tensor->TotalBytes();
+ if (total_bytes > 0) {
+ const void* src_ptr = xla::LiteralUtil::InternalData(literal);
+ void* dst_ptr = DMAHelper::base(host_tensor);
+ memcpy(dst_ptr, src_ptr, total_bytes);
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
new file mode 100644
index 0000000000..3e509375ef
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities for working with XLA Literals.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Copies 'host_tensor' to an XLA Literal. Fails if the host_tensor is of an
+// unsupported type.
+Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
+
+// Copies 'literal' to 'host_tensor', which is allocated with type <target_type>.
+// Fails if the literal's primitive type !=
+// DataTypeToPrimitiveType(target_type). Note that <target_type> is not
+// derivable from the type of <literal>, because multiple tensorflow types map
+// to the same XLA type (e.g. INT32 and QINT32 both map to INT32 in
+// XLA).
+Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
+ Tensor* host_tensor);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc
new file mode 100644
index 0000000000..56993bc585
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/literal_util_test.cc
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(LiteralUtil, LiteralToHostTensor) {
+  // An int64 literal can only be converted to an int64 host tensor.
+ {
+ std::vector<int64> int64_values = {1, 2, 3};
+ std::unique_ptr<xla::Literal> int64_values_literal =
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int64>(int64_values));
+ Tensor host_tensor;
+ EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32",
+ LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor)
+ .error_message());
+ EXPECT_EQ(
+ "Cannot convert literal of type S64 to tensor of type qint32",
+ LiteralToHostTensor(*int64_values_literal, DT_QINT32, &host_tensor)
+ .error_message());
+ EXPECT_TRUE(
+ LiteralToHostTensor(*int64_values_literal, DT_INT64, &host_tensor)
+ .ok());
+ test::ExpectTensorEqual<int64>(host_tensor,
+ test::AsTensor<int64>(int64_values));
+ }
+
+ {
+ // Repeat tests with int32.
+ Tensor host_tensor;
+ std::vector<int32> int32_values = {10, 11};
+ std::unique_ptr<xla::Literal> int32_values_literal =
+ xla::LiteralUtil::CreateR1(gtl::ArraySlice<int32>(int32_values));
+ EXPECT_TRUE(
+ LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor)
+ .ok());
+ test::ExpectTensorEqual<int32>(host_tensor,
+ test::AsTensor<int32>(int32_values));
+
+ EXPECT_TRUE(
+ LiteralToHostTensor(*int32_values_literal, DT_QINT32, &host_tensor)
+ .ok());
+ std::vector<qint32> qint32_values = {10, 11};
+ test::ExpectTensorEqual<qint32>(host_tensor,
+ test::AsTensor<qint32>(qint32_values));
+
+ EXPECT_EQ("Cannot convert literal of type S32 to tensor of type int64",
+ LiteralToHostTensor(*int32_values_literal, DT_INT64, &host_tensor)
+ .error_message());
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/op_registrations.cc b/tensorflow/compiler/tf2xla/op_registrations.cc
new file mode 100644
index 0000000000..d8a4dad4b3
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/op_registrations.cc
@@ -0,0 +1,502 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Kernel registrations for XLA JIT devices.
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace {
+
+// CPU JIT device registrations.
+
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("_Arg").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ArrayToList"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("_ListToArray"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("_Retval").TypeConstraint("T", kCpuAllTypes));
+
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Abs").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Add").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("AddN").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("All"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Any"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("AvgPool").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("AvgPoolGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("BatchMatMul").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("BiasAdd").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("BiasAddV1").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("BiasAddGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("BroadcastGradientArgs"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Cast")
+ .TypeConstraint("SrcT", kCpuAllTypes)
+ .TypeConstraint("DstT", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Ceil").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Concat").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ConcatV2")
+ .TypeConstraint("T", kCpuAllTypes)
+ .TypeConstraint("Tidx", DT_INT32));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ConcatOffset"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Conv2D").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_CPU_XLA_JIT,
+ Name("Conv2DBackpropFilter").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_CPU_XLA_JIT,
+ Name("Conv2DBackpropInput").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_CPU_XLA_JIT,
+ Name("DepthwiseConv2dNative").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Diag").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("DiagPart").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Div").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("DynamicStitch").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Equal").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Exp").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("ExpandDims").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Fill").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Floor").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("FloorDiv").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("FloorMod").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Greater").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("GreaterEqual").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Inv").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Reciprocal").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("InvertPermutation").TypeConstraint("T", DT_INT32));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("L2Loss").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Less").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("LessEqual").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("LinSpace").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Log").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalAnd"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalNot"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("LogicalOr"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("LogSoftmax").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("LRN").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("LRNGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Maximum").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("MatMul").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("MatrixDiag").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("MatrixDiagPart").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Max").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("MaxPool").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("MaxPoolGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Mean").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Min").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Minimum").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Mod").TypeConstraint("T", kCpuIntTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Mul").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Neg").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("NotEqual").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Pack").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Pad").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Pow").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("PreventGradient").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Prod").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Range").TypeConstraint("Tidx", kCpuNumericTypes));
+// TODO(b/31361304): disabled because of XLA bugs.
+// REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomStandardNormal"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniform"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("RandomUniformInt"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Rank"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("RealDiv").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Relu").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Relu6").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("ReluGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Relu6Grad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Reshape").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Rsqrt").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("RsqrtGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Select").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Shape"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("ShapeN"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Sigmoid").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("SigmoidGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Sign").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("Size"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Slice").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Softmax").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_CPU_XLA_JIT,
+ Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Softplus").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("SoftplusGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("SparseMatMul")
+ .TypeConstraint("Ta", kCpuFloatTypes)
+ .TypeConstraint("Tb", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Split").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("SplitV").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Square").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_CPU_XLA_JIT,
+ Name("SquaredDifference").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Squeeze").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Sqrt").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("StopGradient").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("StridedSlice").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("StridedSliceGrad").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Sub").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Sum").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT, Name("SymbolicGradient"));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Tanh").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("TanhGrad").TypeConstraint("T", kCpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Tile").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Transpose").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("TruncateDiv").TypeConstraint("T", kCpuIntTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("TruncateMod").TypeConstraint("T", kCpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Unpack").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("ZerosLike").TypeConstraint("T", kCpuNumericTypes));
+
+REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT,
+ Name("Const").TypeConstraint("dtype",
+ kCpuAllTypes));
+REGISTER_XLA_JIT_ONLY_KERNEL(
+ DEVICE_CPU_XLA_JIT, Name("Identity").TypeConstraint("T", kCpuAllTypes));
+REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_CPU_XLA_JIT, Name("NoOp"));
+
+// GPU JIT device registrations
+
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("_Arg").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ArrayToList"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("_ListToArray"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("_Retval").TypeConstraint("T", kGpuAllTypes));
+
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Abs").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Add").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("AddN").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("All"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Any"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("AvgPool").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("AvgPoolGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("BatchMatMul").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("BiasAdd").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("BiasAddV1").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("BiasAddGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("BroadcastGradientArgs"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Cast")
+ .TypeConstraint("SrcT", kGpuAllTypes)
+ .TypeConstraint("DstT", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Ceil").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Concat").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("ConcatV2").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ConcatOffset"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Conv2D").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_GPU_XLA_JIT,
+ Name("Conv2DBackpropFilter").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_GPU_XLA_JIT,
+ Name("Conv2DBackpropInput").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_GPU_XLA_JIT,
+ Name("DepthwiseConv2dNative").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Diag").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("DiagPart").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Div").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("DynamicStitch").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Equal").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Exp").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("ExpandDims").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Fill").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Floor").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("FloorDiv").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("FloorMod").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Greater").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("GreaterEqual").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Inv").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Reciprocal").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("InvertPermutation").TypeConstraint("T", DT_INT32));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("L2Loss").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Less").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("LessEqual").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("LinSpace").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Log").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalAnd"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalNot"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("LogicalOr"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("LogSoftmax").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("LRN").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("LRNGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Maximum").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("MatMul").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("MatrixDiag").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("MatrixDiagPart").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Max").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("MaxPool").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("MaxPoolGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Mean").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Min").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Minimum").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Mod").TypeConstraint("T", kGpuIntTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Mul").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Neg").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("NotEqual").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Pack").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Pad").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Pow").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("PreventGradient").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Prod").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Range").TypeConstraint("Tidx", kGpuNumericTypes));
+// TODO(b/31361304): disabled because of XLA bugs.
+// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomStandardNormal"));
+// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniform"));
+// REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("RandomUniformInt"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Rank"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("RealDiv").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Relu").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Relu6").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("ReluGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Relu6Grad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Reshape").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Rsqrt").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("RsqrtGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Select").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Shape"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("ShapeN"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Sigmoid").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("SigmoidGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Sign").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("Size"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Slice").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Softmax").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_GPU_XLA_JIT,
+ Name("SoftmaxCrossEntropyWithLogits").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Softplus").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("SoftplusGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("SparseMatMul")
+ .TypeConstraint("Ta", kGpuFloatTypes)
+ .TypeConstraint("Tb", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Split").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("SplitV").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Square").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(
+ DEVICE_GPU_XLA_JIT,
+ Name("SquaredDifference").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Squeeze").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Sqrt").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("StopGradient").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("StridedSlice").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("StridedSliceGrad").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Sub").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Sum").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("SymbolicGradient"));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Tanh").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("TanhGrad").TypeConstraint("T", kGpuFloatTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Tile").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Transpose").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Unpack").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("ZerosLike").TypeConstraint("T", kGpuNumericTypes));
+
+REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT,
+ Name("Const").TypeConstraint("dtype",
+ kGpuAllTypes));
+REGISTER_XLA_JIT_ONLY_KERNEL(
+ DEVICE_GPU_XLA_JIT, Name("Identity").TypeConstraint("T", kGpuAllTypes));
+REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE_GPU_XLA_JIT, Name("NoOp"));
+
+} // anonymous namespace
+} // namespace tensorflow
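The registrations above populate XlaOpRegistry, declared later in this change in xla_compilation_device.h. A minimal sketch of how a tool might enumerate them; LogCpuJitKernels is a hypothetical helper, not part of this change:

  #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
  #include "tensorflow/core/framework/op_kernel.h"
  #include "tensorflow/core/platform/logging.h"

  namespace tensorflow {

  // Logs every op registered on the CPU JIT device. Kernels registered with
  // REGISTER_XLA_JIT_ONLY_KERNEL (Const, Identity, NoOp) are excluded by
  // DeviceKernels().
  void LogCpuJitKernels() {
    for (const KernelDef* kdef :
         XlaOpRegistry::DeviceKernels(DEVICE_CPU_XLA_JIT)) {
      LOG(INFO) << kdef->op();
    }
  }

  }  // namespace tensorflow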
diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc
new file mode 100644
index 0000000000..f5ecb51a5b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/shape_util.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Convert an XLA Shape into the equivalent TensorFlow shape.
+TensorShape XLAShapeToTensorShape(const xla::Shape& shape) {
+ TensorShape tensor_shape;
+ for (int i = 0; i < xla::ShapeUtil::Rank(shape); ++i) {
+ tensor_shape.AddDim(shape.dimensions(i));
+ }
+ return tensor_shape;
+}
+
+// Convert a TensorShape into the equivalent XLA Shape proto.
+Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
+ xla::Shape* shape) {
+ int rank = tensor_shape.dims();
+ std::vector<int64> dimensions(rank);
+ std::vector<int64> layout(rank);
+ for (int d = 0; d < rank; ++d) {
+ dimensions[d] = tensor_shape.dim_size(d);
+ }
+ // XLA uses minor-to-major; Tensorflow uses major-to-minor.
+ std::iota(layout.rbegin(), layout.rend(), 0);
+
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype, &type));
+
+ *shape = xla::ShapeUtil::MakeShapeWithLayout(type, dimensions, layout);
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/shape_util.h b/tensorflow/compiler/tf2xla/shape_util.h
new file mode 100644
index 0000000000..516dd636a9
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/shape_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities for working with XLA shapes.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+
+namespace tensorflow {
+
+// Convert an XLA Shape into the equivalent TensorFlow shape.
+TensorShape XLAShapeToTensorShape(const xla::Shape& shape);
+
+// Convert a TensorShape into the equivalent XLA Shape proto. Unlike a
+// TensorShape, an XLA Shape also carries the element type. Not all 'dtype'
+// values can be represented by XLA, so this conversion may fail.
+Status TensorShapeToXLAShape(DataType dtype, const TensorShape& tensor_shape,
+ xla::Shape* shape);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
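A small sketch of the shape conversions (ShapeRoundTripExample is hypothetical, assuming only the declarations above). The XLA Shape carries the element type and a minor-to-major layout, while the TensorShape keeps only the dimensions, so the round trip recovers {2, 3}:

  #include "tensorflow/compiler/tf2xla/shape_util.h"

  #include "tensorflow/core/framework/tensor_shape.h"
  #include "tensorflow/core/lib/core/errors.h"

  namespace tensorflow {

  // Converts a 2x3 float TensorShape to an XLA Shape and back.
  Status ShapeRoundTripExample() {
    TensorShape tf_shape({2, 3});
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(DT_FLOAT, tf_shape, &xla_shape));
    // xla_shape is an F32 shape with dimensions {2, 3} and layout {1, 0}.
    TensorShape back = XLAShapeToTensorShape(xla_shape);
    if (back != tf_shape) {
      return errors::Internal("Shape round trip mismatch");
    }
    return Status::OK();
  }

  }  // namespace tensorflow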
diff --git a/tensorflow/compiler/tf2xla/str_util.cc b/tensorflow/compiler/tf2xla/str_util.cc
new file mode 100644
index 0000000000..ce25d63127
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/str_util.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/str_util.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace tensorflow {
+namespace str_util {
+
+void ReplaceAll(string* text, StringPiece from, StringPiece to) {
+ size_t pos = 0;
+ while ((pos = text->find(from.data(), pos, from.size())) != string::npos) {
+ text->replace(pos, from.size(), to.data(), to.size());
+ pos += to.size();
+ if (from.empty()) {
+ pos++; // Match at the beginning of the text and after every byte
+ }
+ }
+}
+
+void ReplaceAllPairs(string* text,
+ const std::vector<std::pair<string, string>>& replace) {
+ for (const std::pair<string, string>& from_to : replace) {
+ ReplaceAll(text, from_to.first, from_to.second);
+ }
+}
+
+} // namespace str_util
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/str_util.h b/tensorflow/compiler/tf2xla/str_util.h
new file mode 100644
index 0000000000..4920b1a4d4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/str_util.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// String utilities that are esoteric enough that they don't belong in
+// third_party/tensorflow/core/lib/strings/str_util.h, but are still generally
+// useful under xla.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace str_util {
+
+// Replace all non-overlapping occurrences of from with to in-place in text. If
+// from is empty, it matches at the beginning of the text and after every byte.
+void ReplaceAll(string* text, StringPiece from, StringPiece to);
+
+// Replace all non-overlapping occurrences of the given (from,to) pairs in-place
+// in text. If from is empty, it matches at the beginning of the text and after
+// every byte. Each (from,to) replacement pair is processed in the order it is
+// given.
+void ReplaceAllPairs(string* text,
+ const std::vector<std::pair<string, string>>& replace);
+
+} // namespace str_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_STR_UTIL_H_
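These helpers are handy for simple {{placeholder}} template expansion, which is exactly the pattern the tests below exercise. A hedged sketch of typical use; the template text and ExpandTemplateExample are illustrative only:

  #include <string>
  #include <utility>
  #include <vector>

  #include "tensorflow/compiler/tf2xla/str_util.h"

  namespace tensorflow {

  // Expands {{placeholders}} in a code template. Pairs are applied in order,
  // so later replacements see the output of earlier ones.
  std::string ExpandTemplateExample() {
    std::string text = "class {{CLASS}} { void {{METHOD}}(); };";
    str_util::ReplaceAllPairs(
        &text, {{"{{CLASS}}", "MyGraph"}, {"{{METHOD}}", "Run"}});
    return text;  // "class MyGraph { void Run(); };"
  }

  }  // namespace tensorflow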
diff --git a/tensorflow/compiler/tf2xla/str_util_test.cc b/tensorflow/compiler/tf2xla/str_util_test.cc
new file mode 100644
index 0000000000..f992007a34
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/str_util_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/str_util.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace str_util {
+
+class ReplaceAllTest : public ::testing::Test {
+ protected:
+ void ExpectReplaceAll(string text, StringPiece from, StringPiece to,
+ StringPiece want) {
+ ReplaceAll(&text, from, to);
+ EXPECT_EQ(text, want);
+ }
+};
+
+TEST_F(ReplaceAllTest, Simple) {
+ ExpectReplaceAll("", "", "", "");
+ ExpectReplaceAll("", "", "X", "X");
+ ExpectReplaceAll("", "", "XYZ", "XYZ");
+ ExpectReplaceAll("banana", "", "", "banana");
+ ExpectReplaceAll("banana", "", "_", "_b_a_n_a_n_a_");
+ ExpectReplaceAll("banana", "", "__", "__b__a__n__a__n__a__");
+ ExpectReplaceAll("banana", "a", "a", "banana");
+ ExpectReplaceAll("banana", "a", "", "bnn");
+ ExpectReplaceAll("banana", "a", "X", "bXnXnX");
+ ExpectReplaceAll("banana", "a", "XX", "bXXnXXnXX");
+ ExpectReplaceAll("banana", "an", "an", "banana");
+ ExpectReplaceAll("banana", "an", "", "ba");
+ ExpectReplaceAll("banana", "an", "X", "bXXa");
+ ExpectReplaceAll("banana", "an", "XY", "bXYXYa");
+ ExpectReplaceAll("banana", "an", "XYZ", "bXYZXYZa");
+ ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "X", "foo X baz X");
+ ExpectReplaceAll("foo {{bar}} baz {{bar}}", "{{bar}}", "ABCDEFGHIJKLMNOP",
+ "foo ABCDEFGHIJKLMNOP baz ABCDEFGHIJKLMNOP");
+}
+
+class ReplaceAllPairsTest : public ::testing::Test {
+ protected:
+ void ExpectReplaceAllPairs(
+ string text, const std::vector<std::pair<string, string>>& replace,
+ StringPiece want) {
+ ReplaceAllPairs(&text, replace);
+ EXPECT_EQ(text, want);
+ }
+};
+
+TEST_F(ReplaceAllPairsTest, Simple) {
+ ExpectReplaceAllPairs("", {}, "");
+ ExpectReplaceAllPairs("", {{"", ""}}, "");
+ ExpectReplaceAllPairs("", {{"", "X"}}, "X");
+ ExpectReplaceAllPairs("", {{"", "XYZ"}}, "XYZ");
+ ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}}, "_X_Y_Z_");
+ ExpectReplaceAllPairs("", {{"", "XYZ"}, {"", "_"}, {"_Y_", "a"}}, "_XaZ_");
+ ExpectReplaceAllPairs("banana", {}, "banana");
+ ExpectReplaceAllPairs("banana", {{"", ""}}, "banana");
+ ExpectReplaceAllPairs("banana", {{"", "_"}}, "_b_a_n_a_n_a_");
+ ExpectReplaceAllPairs("banana", {{"", "__"}}, "__b__a__n__a__n__a__");
+ ExpectReplaceAllPairs("banana", {{"a", "a"}}, "banana");
+ ExpectReplaceAllPairs("banana", {{"a", ""}}, "bnn");
+ ExpectReplaceAllPairs("banana", {{"a", "X"}}, "bXnXnX");
+ ExpectReplaceAllPairs("banana", {{"a", "XX"}}, "bXXnXXnXX");
+ ExpectReplaceAllPairs("banana", {{"a", "XX"}, {"XnX", "z"}}, "bXzzX");
+ ExpectReplaceAllPairs("a{{foo}}b{{bar}}c{{foo}}",
+ {{"{{foo}}", "0"}, {"{{bar}}", "123456789"}},
+ "a0b123456789c0");
+}
+
+} // namespace str_util
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
new file mode 100644
index 0000000000..b54848f342
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) {
+ switch (data_type) {
+ case tensorflow::DT_BOOL:
+ *type = xla::PRED;
+ return Status::OK();
+ case tensorflow::DT_INT8:
+ *type = xla::S8;
+ return Status::OK();
+ case tensorflow::DT_INT16:
+ *type = xla::S16;
+ return Status::OK();
+ case tensorflow::DT_INT32:
+ *type = xla::S32;
+ return Status::OK();
+ case tensorflow::DT_INT64:
+ *type = xla::S64;
+ return Status::OK();
+ case tensorflow::DT_UINT8:
+ *type = xla::U8;
+ return Status::OK();
+ case tensorflow::DT_UINT16:
+ *type = xla::U16;
+ return Status::OK();
+ case tensorflow::DT_HALF:
+ *type = xla::F16;
+ return Status::OK();
+ case tensorflow::DT_FLOAT:
+ *type = xla::F32;
+ return Status::OK();
+ case tensorflow::DT_DOUBLE:
+ *type = xla::F64;
+ return Status::OK();
+ case tensorflow::DT_QUINT8:
+ *type = xla::U8;
+ return Status::OK();
+ case tensorflow::DT_QINT32:
+ *type = xla::S32;
+ return Status::OK();
+ default:
+ return errors::InvalidArgument(
+ "Unsupported type in DataTypeToPrimitiveType ",
+ DataTypeString(data_type));
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
new file mode 100644
index 0000000000..bda667eb1f
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Converts a Tensorflow DataType to an XLA PrimitiveType.
+Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
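The mapping is many-to-one, as the switch in type_util.cc shows. A tiny sketch (TypeMappingExample is hypothetical) of the behavior a caller can rely on:

  #include "tensorflow/compiler/tf2xla/type_util.h"

  #include "tensorflow/core/framework/types.h"
  #include "tensorflow/core/platform/logging.h"

  namespace tensorflow {

  void TypeMappingExample() {
    xla::PrimitiveType t;
    // DT_INT32 and DT_QINT32 both map to xla::S32.
    TF_CHECK_OK(DataTypeToPrimitiveType(DT_INT32, &t));
    CHECK_EQ(t, xla::S32);
    TF_CHECK_OK(DataTypeToPrimitiveType(DT_QINT32, &t));
    CHECK_EQ(t, xla::S32);
    // Types with no XLA equivalent return an InvalidArgument error.
    Status s = DataTypeToPrimitiveType(DT_STRING, &t);
    CHECK(!s.ok());
  }

  }  // namespace tensorflow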
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
new file mode 100644
index 0000000000..86a53c929e
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -0,0 +1,203 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+
+#include <functional>
+#include <memory>
+
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace tensorflow {
+
+const char* const DEVICE_CPU_XLA_JIT = "XLA_CPU_JIT";
+const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT";
+
+// The XlaCompilationAllocator doesn't actually back any Tensors with storage
+// buffers of values: instead, for each Tensor it stores an
+// XlaExpression which corresponds to the XLA computation
+// represented by the Tensor.
+class XlaCompilationAllocator : public Allocator {
+ public:
+ XlaCompilationAllocator() {}
+ ~XlaCompilationAllocator() override {}
+
+  string Name() override { return "xla_jit"; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    // Regardless of the size requested, always allocate an
+    // XlaExpression. Respect the alignment request because there is
+    // alignment checking even for Tensors whose data is never
+    // accessed.
+ void* p = port::aligned_malloc(sizeof(XlaExpression), alignment);
+ XlaExpression* expression = reinterpret_cast<XlaExpression*>(p);
+ new (expression) XlaExpression();
+ return expression;
+ }
+
+ void DeallocateRaw(void* ptr) override {
+ XlaExpression* expression = reinterpret_cast<XlaExpression*>(ptr);
+ expression->~XlaExpression();
+ port::aligned_free(ptr);
+ }
+
+ // Make sure that even tensors with 0 elements have allocated
+ // buffers, so they get ids to track.
+ bool ShouldAllocateEmptyTensors() override { return true; }
+
+ void GetStats(AllocatorStats* stats) override { stats->Clear(); }
+
+ private:
+ // Don't run any constructors or destructors for complex objects,
+ // since there is no backing store for the tensor to run them
+ // on. strings are the only complex objects currently stored in
+ // Tensors. If others are added, this set of overrides must be
+ // extended to include them.
+ void RunStringCtor(string* p, size_t n) override {}
+ void RunStringDtor(string* p, size_t n) override {}
+ void RunResourceCtor(ResourceHandle* p, size_t n) override {}
+ void RunResourceDtor(ResourceHandle* p, size_t n) override {}
+};
+
+XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
+ DeviceType type)
+ : LocalDevice(options,
+ Device::BuildDeviceAttributes(
+ "", type, Bytes(256 << 20), DeviceLocality(),
+ strings::StrCat("device: XLA JIT device ", type.type())),
+ cpu_allocator()),
+ allocator_(new XlaCompilationAllocator()) {}
+
+XlaCompilationDevice::~XlaCompilationDevice() {}
+
+Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
+ return allocator_.get();
+}
+
+Status XlaCompilationDevice::Sync() { return Status::OK(); }
+
+Status XlaCompilationDevice::MakeTensorFromProto(
+ const TensorProto& tensor_proto, const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) {
+ return errors::InvalidArgument(
+ "Tla JIT Device should not parse tensor from proto");
+}
+
+// Is platform 'id' supported by XLA?
+static bool IsPlatformSupported(perftools::gputools::Platform::Id id) {
+ auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id);
+ if (!platform.ok()) return false;
+ return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok();
+}
+
+XlaOpRegistry::XlaOpRegistry() = default;
+XlaOpRegistry::~XlaOpRegistry() = default;
+
+/* static */ void XlaOpRegistry::RegisterJitDevice(
+ const string& device_name, const string& jit_device_name,
+ bool requires_jit) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto result = registry.jit_devices_.emplace(
+ device_name, std::make_pair(jit_device_name, requires_jit));
+ CHECK(result.second || result.first->second.first == jit_device_name);
+}
+
+/* static */ bool XlaOpRegistry::GetJitDevice(const string& device_name,
+ const string** jit_device_name,
+ bool* requires_jit) {
+ XlaOpRegistry& registry = Instance();
+
+ // Lazily register the CPU and GPU JIT devices the first time GetJitDevice is
+ // called.
+ static void* registration = [&registry]() {
+ mutex_lock lock(registry.mutex_);
+ if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) {
+ registry.jit_devices_[DEVICE_CPU] = {DEVICE_CPU_XLA_JIT, false};
+ }
+ if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) {
+ registry.jit_devices_[DEVICE_GPU] = {DEVICE_GPU_XLA_JIT, false};
+ }
+ return nullptr;
+ }();
+ (void)registration;
+
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.jit_devices_.find(device_name);
+ if (it == registry.jit_devices_.end()) return false;
+ if (jit_device_name) *jit_device_name = &it->second.first;
+ if (requires_jit) *requires_jit = it->second.second;
+ return true;
+}
+
+void XlaOpRegistry::RegisterJitKernels() {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+
+ if (registry.jit_kernels_registered_) return;
+ registry.jit_kernels_registered_ = true;
+
+ for (const auto& entry : registry.kernels_) {
+ for (const XlaKernel& k : entry.second) {
+ auto it = registry.ops_.find(k.kernel_def->op());
+ CHECK(it != registry.ops_.end()) << "Missing XLA op registration for op "
+ << k.kernel_def->op();
+ registry.kernel_registrars_.emplace_back(
+ new kernel_factory::OpKernelRegistrar(new KernelDef(*k.kernel_def),
+ "XlaJitOp", it->second));
+ }
+ }
+}
+
+std::vector<const KernelDef*> XlaOpRegistry::DeviceKernels(
+ const string& jit_device_type) {
+ std::vector<const KernelDef*> kernels;
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ for (const XlaKernel& k : registry.kernels_.at(jit_device_type)) {
+ if (!k.jit_only) {
+ kernels.push_back(k.kernel_def.get());
+ }
+ }
+ return kernels;
+}
+
+XlaOpRegistry& XlaOpRegistry::Instance() {
+ static XlaOpRegistry* r = new XlaOpRegistry;
+ return *r;
+}
+
+XlaOpRegistrar::XlaOpRegistrar(StringPiece name,
+ XlaOpRegistry::Factory factory) {
+ XlaOpRegistry& registry = XlaOpRegistry::Instance();
+ mutex_lock lock(registry.mutex_);
+ CHECK(registry.ops_.emplace(name.ToString(), factory).second)
+ << "Duplicate XLA op registration " << name;
+}
+
+XlaKernelRegistrar::XlaKernelRegistrar(bool jit_only, const KernelDef* def) {
+ XlaOpRegistry& registry = XlaOpRegistry::Instance();
+ mutex_lock lock(registry.mutex_);
+ registry.kernels_[def->device_type()].push_back(XlaOpRegistry::XlaKernel{
+ jit_only, std::unique_ptr<const KernelDef>(def)});
+}
+
+} // namespace tensorflow
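The lazy registration in GetJitDevice means the first query decides whether the CPU and GPU JIT devices are available in this build. A hypothetical caller (sketch only) looks like this:

  #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"

  #include "tensorflow/core/framework/types.h"
  #include "tensorflow/core/platform/logging.h"

  namespace tensorflow {

  void LogCpuJitDevice() {
    const string* jit_device_name;
    bool requires_jit;
    // DEVICE_CPU is the plain "CPU" TensorFlow device type.
    if (XlaOpRegistry::GetJitDevice(DEVICE_CPU, &jit_device_name,
                                    &requires_jit)) {
      LOG(INFO) << "CPU ops can be compiled via " << *jit_device_name;
    } else {
      LOG(INFO) << "No XLA JIT device is registered for CPU";
    }
  }

  }  // namespace tensorflow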
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
new file mode 100644
index 0000000000..f4b95b874b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -0,0 +1,214 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
+
+#include <map>
+#include <memory>
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Names of the XLA JIT devices. These are not user-visible, and are used
+// internally by the JIT to perform symbolic execution of a Tensorflow graph.
+
+extern const char* const DEVICE_CPU_XLA_JIT;  // "XLA_CPU_JIT"
+extern const char* const DEVICE_GPU_XLA_JIT;  // "XLA_GPU_JIT"
+
+constexpr std::array<DataType, 5> kCpuAllTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+constexpr std::array<DataType, 2> kCpuIntTypes = {{DT_INT32, DT_INT64}};
+constexpr std::array<DataType, 2> kCpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
+constexpr std::array<DataType, 4> kCpuNumericTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
+
+constexpr std::array<DataType, 5> kGpuAllTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}};
+constexpr std::array<DataType, 2> kGpuIntTypes = {{DT_INT32, DT_INT64}};
+constexpr std::array<DataType, 2> kGpuFloatTypes = {{DT_FLOAT, DT_DOUBLE}};
+constexpr std::array<DataType, 4> kGpuNumericTypes = {
+ {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE}};
+
+// The class is declared and defined in xla_compilation_device.cc; it is
+// forward-declared here only so the XlaCompilationDevice allocator_ member can
+// be defined.
+class XlaCompilationAllocator;
+
+// Deliberately don't register the device factory because we *never*
+// want soft placement to put Ops on a JIT device. Tests can include
+// the tla_jit_test_deps target which registers the factory, and when
+// using JIT in practice, the device is created manually, not via a
+// factory.
+
+// This is a 'dummy' TensorFlow device that is only used to execute a
+// subgraph of XLA compilation Ops to construct a compiled version
+// of the subgraph's computation. It has a 'dummy' allocator that
+// backs each Tensor with metadata indicating the computation the
+// Tensor represents.
+class XlaCompilationDevice : public LocalDevice {
+ public:
+ XlaCompilationDevice(const SessionOptions& options, DeviceType type);
+
+ ~XlaCompilationDevice() override;
+
+ Allocator* GetAllocator(AllocatorAttributes attr) override;
+
+ Status Sync() override;
+
+ Status MakeTensorFromProto(const TensorProto& tensor_proto,
+ const AllocatorAttributes alloc_attrs,
+ Tensor* tensor) override;
+
+ private:
+ std::unique_ptr<XlaCompilationAllocator> allocator_;
+};
+
+// Class that manages registrations of operators and devices for the XLA JIT.
+// Not thread-safe.
+class XlaOpRegistry {
+ public:
+ typedef OpKernel* (*Factory)(OpKernelConstruction*);
+
+ // Registers 'jit_device_name' as the JIT device corresponding to
+ // 'device_name'. If 'requires_jit' is true, then operators placed on this
+ // device must be JIT-compiled. Dies if a conflicting registration already
+ // exists.
+ static void RegisterJitDevice(const string& device_name,
+ const string& jit_device_name,
+ bool requires_jit);
+
+ // Returns the JIT device name associated with 'device_name', setting
+ // 'jit_device_name' and 'requires_jit', if they are not null. Returns false
+ // and leaves 'jit_device_name' and 'requires_jit' unchanged if no matching
+ // JIT device is registered.
+ static bool GetJitDevice(const string& device_name,
+ const string** jit_device_name, bool* requires_jit);
+
+ // Registers all JIT kernels on JIT devices, if not already registered.
+ // Does nothing otherwise.
+ static void RegisterJitKernels();
+
+ // Returns KernelDefs for JIT ops registered on 'jit_device_type'.
+ // Does not include kernels registered using REGISTER_XLA_JIT_ONLY_KERNEL.
+ static std::vector<const KernelDef*> DeviceKernels(
+ const string& jit_device_type);
+
+ private:
+ friend class XlaKernelRegistrar;
+ friend class XlaOpRegistrar;
+
+ static XlaOpRegistry& Instance();
+
+ XlaOpRegistry();
+ ~XlaOpRegistry();
+
+ mutex mutex_;
+
+ // Map from Tensorflow device names to the corresponding JIT device names.
+ std::unordered_map<string, std::pair<string, bool>> jit_devices_
+ GUARDED_BY(mutex_);
+
+ // Map from operator name to OpKernel factory, populated by REGISTER_XLA_OP.
+ std::unordered_map<string, Factory> ops_ GUARDED_BY(mutex_);
+
+ // Have we already registered the JIT kernels on the JIT devices?
+ bool jit_kernels_registered_ = false;
+
+ struct XlaKernel {
+ // Should this kernel be registered only on JIT devices, without a dummy
+ // kernel registered on the corresponding XLA device?
+ bool jit_only;
+
+ // KernelDef as built by REGISTER_XLA_KERNEL.
+ std::unique_ptr<const KernelDef> kernel_def;
+ };
+
+ // Map from JIT device name to a vector of XLA kernel descriptors.
+ std::unordered_map<string, std::vector<XlaKernel>> kernels_
+ GUARDED_BY(mutex_);
+
+ // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
+  // registrations created by RegisterJitKernels().
+ std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
+ kernel_registrars_ GUARDED_BY(mutex_);
+};
+
+// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
+// REGISTER_XLA_OP("Add", AddOp);
+// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
+//
+// We don't use a variadic macro here because we don't expect JIT operators to
+// be templated.
+
+#define REGISTER_XLA_OP(NAME, OP) \
+ REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
+
+// REGISTER_XLA_KERNEL() associates an XLA OpKernel with a particular device and
+// set of type constraints, e.g.,
+//   REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
+// Name("Relu").TypeConstraint("T", DT_FLOAT));
+//
+// REGISTER_XLA_JIT_ONLY_KERNEL is similar to REGISTER_XLA_KERNEL(), but causes
+// XlaOpRegistry::DeviceKernels() to omit the kernel.
+
+#define REGISTER_XLA_KERNEL(DEVICE, BUILDER) \
+ REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, false)
+
+#define REGISTER_XLA_JIT_ONLY_KERNEL(DEVICE, BUILDER) \
+ REGISTER_XLA_KERNEL_UNIQ_HELPER(__COUNTER__, DEVICE, BUILDER, true)
+
+// Implementation details.
+
+class XlaOpRegistrar {
+ public:
+ XlaOpRegistrar(StringPiece name, XlaOpRegistry::Factory factory);
+};
+
+#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, NAME, OP) \
+ REGISTER_XLA_OP_UNIQ(COUNTER, NAME, OP)
+
+#define REGISTER_XLA_OP_UNIQ(CTR, NAME, OP) \
+ static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
+ NAME, [](::tensorflow::OpKernelConstruction* context) \
+ -> ::tensorflow::OpKernel* { return new OP(context); });
+
+// Implementation details.
+class XlaKernelRegistrar {
+ public:
+ XlaKernelRegistrar(bool jit_only, const KernelDef* def);
+};
+
+#define REGISTER_XLA_KERNEL_UNIQ_HELPER(COUNTER, DEVICE, BUILDER, JIT_ONLY) \
+ REGISTER_XLA_KERNEL_UNIQ(COUNTER, DEVICE, BUILDER, JIT_ONLY)
+
+#define REGISTER_XLA_KERNEL_UNIQ(CTR, DEVICE, BUILDER, JIT_ONLY) \
+ static ::tensorflow::XlaKernelRegistrar \
+ xla_kernel_registrar__body__##CTR##__object( \
+ JIT_ONLY, \
+ ::tensorflow::register_kernel::BUILDER.Device(DEVICE).Build());
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
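Putting the two macros together, a new kernel registration looks like the sketch below. MyTestOp is a made-up op name and its body is elided; real kernels in this change run on the symbolic XlaCompilationDevice and emit XLA operations from Compute() instead of touching tensor data:

  #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
  #include "tensorflow/core/framework/op_kernel.h"

  namespace tensorflow {
  namespace {

  // Illustrative kernel; a real XLA kernel would emit XLA operations here.
  class MyTestOp : public OpKernel {
   public:
    explicit MyTestOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    void Compute(OpKernelContext* ctx) override {}
  };

  // Associate the implementation with the op name...
  REGISTER_XLA_OP("MyTestOp", MyTestOp);
  // ...and make it available on the CPU JIT device for float inputs.
  REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
                      Name("MyTestOp").TypeConstraint("T", DT_FLOAT));

  }  // anonymous namespace
  }  // namespace tensorflow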
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
new file mode 100644
index 0000000000..e46c2a3148
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -0,0 +1,405 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/version.h"
+
+namespace tensorflow {
+
+namespace {
+
+bool HasRetval(const Graph& graph) {
+ for (const Node* n : graph.nodes()) {
+ if (n->type_string() == "_Retval") return true;
+ }
+ return false;
+}
+
+Status CheckSignature(const DataTypeVector& tf_types,
+ const xla::Shape& xla_shape) {
+ if (xla::ShapeUtil::IsTuple(xla_shape)) {
+ if (xla::ShapeUtil::TupleElementCount(xla_shape) != tf_types.size()) {
+ return errors::Internal("XLA shape has ",
+ xla::ShapeUtil::TupleElementCount(xla_shape),
+ " elements while function has ", tf_types.size());
+ }
+ for (int i = 0; i < tf_types.size(); ++i) {
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[i], &type));
+ if (type !=
+ xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type()) {
+ return errors::Internal(
+ "element ", i, " has XLA type ",
+ xla::ShapeUtil::GetTupleElementShape(xla_shape, i).element_type(),
+ " and TensorFlow type ", DataTypeString(tf_types[i]));
+ }
+ }
+ } else {
+ if (tf_types.size() != 1) {
+ return errors::Internal("Expected singleton type, got ", tf_types.size(),
+ " types");
+ }
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(tf_types[0], &type));
+ if (type != xla_shape.element_type()) {
+ return errors::Internal("singleton element has XLA type ",
+ xla_shape.element_type(), " and TensorFlow type ",
+ DataTypeString(tf_types[0]));
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+XlaCompiler::XlaCompiler(const XlaCompiler::Options& options)
+ : client_(options.client),
+ allow_cpu_custom_calls_(options.allow_cpu_custom_calls),
+ local_executable_has_hybrid_result_(
+ options.local_executable_has_hybrid_result),
+ next_step_id_(1),
+ device_(new XlaCompilationDevice(SessionOptions(), options.device_type)),
+ device_mgr_({device_}) {}
+
+XlaCompiler::~XlaCompiler() = default;
+
+int64 XlaCompiler::NextStepId() {
+ mutex_lock l(mu_);
+ return next_step_id_++;
+}
+
+Status XlaCompiler::CompileFunction(
+ FunctionLibraryRuntime* flr, const NameAttrList& function,
+ const std::vector<XlaCompiler::Argument>& args,
+ XlaCompiler::CompilationResult* result) {
+ const string function_id = Canonicalize(function.name(), function.attr());
+ VLOG(1) << "XlaCompiler::CompileFunction " << function_id;
+
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(
+ flr->Instantiate(function.name(), function.attr(), &handle));
+
+ const FunctionBody* fbody = flr->GetFunctionBody(handle);
+ CHECK(fbody);
+
+ return CompileFunctionBody(flr, *fbody, function_id, args,
+ /*use_tuple_arg=*/false, result);
+}
+
+Status XlaCompiler::CompileSubComputation(FunctionLibraryRuntime* flr,
+ const NameAttrList& function,
+ const xla::Shape& input_shape,
+ const xla::Shape& output_shape,
+ xla::Computation* computation) {
+ const string function_id = Canonicalize(function.name(), function.attr());
+ VLOG(1) << "XlaCompiler::CompileSubComputation " << function_id;
+
+ FunctionLibraryRuntime::Handle handle;
+ TF_RETURN_IF_ERROR(
+ flr->Instantiate(function.name(), function.attr(), &handle));
+
+ const FunctionBody* fbody = flr->GetFunctionBody(handle);
+ CHECK(fbody);
+
+ TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, input_shape));
+ TF_RETURN_IF_ERROR(CheckSignature(fbody->ret_types, output_shape));
+
+ const bool use_tuple_arg = xla::ShapeUtil::IsTuple(input_shape);
+
+ std::vector<XlaCompiler::Argument> args(fbody->arg_types.size());
+ if (use_tuple_arg) {
+ for (int i = 0; i < args.size(); ++i) {
+ xla::Shape xla_shape =
+ xla::ShapeUtil::GetTupleElementShape(input_shape, i);
+ args[i].type = fbody->arg_types[i];
+ args[i].shape = XLAShapeToTensorShape(xla_shape);
+ args[i].parameter = i;
+ }
+ } else {
+ args[0].type = fbody->arg_types[0];
+ args[0].shape = XLAShapeToTensorShape(input_shape);
+ args[0].parameter = 0;
+ }
+
+ CompilationResult result;
+ TF_RETURN_IF_ERROR(CompileFunctionBody(flr, *fbody, function_id, args,
+ use_tuple_arg, &result));
+
+ if (!xla::ShapeUtil::Compatible(result.xla_output_shape, output_shape)) {
+ return errors::Internal("output shape mismatch from compilation");
+ }
+ *computation = std::move(result.computation);
+
+ return Status::OK();
+}
+
+Status XlaCompiler::CompileFunctionBody(
+ FunctionLibraryRuntime* flr, const FunctionBody& fbody,
+ const string& function_id, const std::vector<XlaCompiler::Argument>& args,
+ bool use_tuple_arg, XlaCompiler::CompilationResult* result) {
+ VLOG(1) << "XlaCompiler::CompileFunctionBody " << function_id;
+
+ std::unique_ptr<Graph> graph(new Graph(flr->GetFunctionLibraryDefinition()));
+ CopyGraph(*fbody.graph, graph.get());
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("xla_jit_raw_input_", function_id), *graph);
+ }
+
+ if (!HasRetval(*graph)) {
+ VLOG(1) << "Graph has no retvals. Skipping compilation.";
+ return Status::OK();
+ }
+
+  // Optimize the graph before running it through the translator.
+ // TODO(pbar) The constant folder currently does not simplify int32 operations
+ // for devices other than CPU.
+ OptimizerOptions opts;
+ GraphOptimizer optimizer(opts);
+ Graph* g = graph.release();
+ OptimizeGraph(flr, &g);
+ graph.reset(g);
+
+ if (VLOG_IS_ON(1)) {
+ dump_graph::DumpGraphToFile(
+ strings::StrCat("xla_jit_final_graph_", function_id), *graph);
+ }
+
+ VLOG(1) << "====================================================";
+ TF_RETURN_IF_ERROR(CompileGraph(function_id, std::move(graph), flr, args,
+ use_tuple_arg, result));
+ VLOG(1) << "====================================================";
+
+ return Status::OK();
+}
+
+Status XlaCompiler::BuildExecutable(
+ const XlaCompiler::CompilationResult& result,
+ std::unique_ptr<xla::LocalExecutable>* executable) {
+ VLOG(2) << "Compiling to local executable";
+ xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
+
+ std::vector<const xla::Shape*> argument_layouts(
+ result.xla_input_shapes.size());
+ for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
+ argument_layouts[i] = &result.xla_input_shapes[i].second;
+ }
+ if (result.requires_runtime_context) {
+ // The final arg is the XlaLocalRuntimeContext*.
+ argument_layouts.push_back(&opaque_shape);
+ }
+ xla::LocalClient* local_client = static_cast<xla::LocalClient*>(client());
+ xla::ExecutableBuildOptions build_options;
+ build_options.set_device_ordinal(local_client->default_device_ordinal());
+ build_options.set_platform(local_client->platform());
+ build_options.set_result_layout(result.xla_output_shape);
+ build_options.set_has_hybrid_result(local_executable_has_hybrid_result_);
+
+ auto compile_result = local_client->Compile(result.computation,
+ argument_layouts, build_options);
+ if (!compile_result.ok()) {
+ return compile_result.status();
+ }
+ *executable = std::move(compile_result.ValueOrDie());
+ return Status::OK();
+}
+
+namespace {
+
+Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
+ XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
+ int64 step_id) {
+ // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the
+ // resource manager takes ownership via Create, and unrefs via Cleanup. We
+ // explicitly add a reference to ensure the refcount at entry is maintained at
+ // all exit points; Create and Cleanup are always called in this function.
+ //
+ // The Executor requires us to use ScopedStepContainer. We wrap it in a
+  // unique_ptr so we can capture the cleanup status at the end.
+ xla_context->Ref();
+ Status cleanup_status;
+ auto step_container = xla::MakeUnique<ScopedStepContainer>(
+ step_id, [&cleanup_status, device](const string& name) {
+ cleanup_status = device->resource_manager()->Cleanup(name);
+ });
+ TF_RETURN_IF_ERROR(device->resource_manager()->Create(
+ step_container->name(), XlaContext::kXlaContextResourceName,
+ xla_context));
+
+ // Create a LocalExecutor that will own and run the graph.
+ LocalExecutorParams exec_params;
+ exec_params.device = device;
+ exec_params.function_library = flib;
+ exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) {
+ return flib->CreateKernel(ndef, kernel);
+ };
+ exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; };
+ Executor* exec_ptr = nullptr;
+ TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr));
+ std::unique_ptr<Executor> exec(exec_ptr);
+ // At this point ownership of the graph has been transferred to exec.
+
+ auto runner = [](Executor::Args::Closure c) {
+ // TODO(misard) Temporarily just schedule c eagerly while we
+ // decide what to do about the fact that the ComputationBuilder is
+ // thread-compatible, but we don't really want Op writers to have
+ // to remember to acquire a lock around every call to
+ // ComputationBuilder. One possibility is to add the (generally
+ // useful) ability to run a single-threaded Executor based on an
+ // option in LocalExecutorParams. Another is to automagically
+ // acquire a lock around ComputationBuilder calls using some
+ // wrapper or RAII funny business.
+ c();
+ };
+
+ // Run the graph symbolically, turning the graph into an XLA computation.
+ Executor::Args exec_args;
+ exec_args.step_id = step_id;
+ exec_args.step_container = step_container.get();
+ exec_args.runner = runner;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ exec->Run(exec_args),
+ "Conversion from TensorFlow graph to XLA computation failed.");
+
+ // Explicitly clean up the step container, to capture the cleanup status.
+ step_container.reset();
+ return cleanup_status;
+}
+
+} // namespace
+
+Status XlaCompiler::CompileGraph(string const& name,
+ std::unique_ptr<Graph> graph,
+ FunctionLibraryRuntime* flib,
+ const std::vector<XlaCompiler::Argument>& args,
+ bool use_tuple_arg,
+ CompilationResult* result) {
+ VLOG(1) << "Executing graph symbolically to populate ComputationBuilder.";
+
+ // Converts the input shapes into xla::Shape instances.
+ result->xla_input_shapes.reserve(args.size());
+ for (int i = 0; i < args.size(); ++i) {
+ if (args[i].parameter < 0) {
+ continue;
+ }
+ result->xla_input_shapes.push_back(std::make_pair(i, xla::Shape()));
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(
+ args[i].type, args[i].shape, &result->xla_input_shapes.back().second));
+ }
+
+ XlaContext* xla_context =
+ new XlaContext(client(), name, allow_cpu_custom_calls_);
+ core::ScopedUnref xla_context_unref(xla_context);
+
+ TF_RETURN_IF_ERROR(xla_context->BuildArguments(args, use_tuple_arg));
+
+ TF_RETURN_IF_ERROR(
+ ExecuteGraph(xla_context, std::move(graph), device_, flib, NextStepId()));
+
+ std::vector<XlaContext::ConstRetVal> compile_time_constants;
+ int num_nonconst_outputs;
+ TF_RETURN_IF_ERROR(xla_context->CollectResults(
+ &result->computation, &result->requires_runtime_context,
+ &compile_time_constants, &num_nonconst_outputs));
+
+ result->outputs.resize(compile_time_constants.size() + num_nonconst_outputs);
+ for (const auto& c : compile_time_constants) {
+ if (!c.status.ok()) {
+ Status constant_status = c.status;
+ errors::AppendToMessage(&constant_status,
+ "Failed evaluating constant XLA return "
+ "value ",
+ c.index);
+ return constant_status;
+ }
+ if (c.index >= result->outputs.size()) {
+      return errors::InvalidArgument("Invalid retval index ", c.index);
+ }
+ OutputDescription& output = result->outputs[c.index];
+ output.shape = c.value.shape();
+ output.is_constant = true;
+ output.constant_value = c.value;
+ }
+
+ if (result->computation.IsNull()) {
+ return Status::OK();
+ }
+
+ // Compute the output shapes, if there is a computation with non-constant
+ // outputs.
+ auto computation_shape = client()->GetComputationShape(result->computation);
+ if (!computation_shape.ok()) {
+ return computation_shape.status();
+ }
+
+ result->xla_output_shape.Swap(
+ computation_shape.ValueOrDie()->mutable_result());
+
+ auto num_non_constant_outputs =
+ (xla::ShapeUtil::IsTuple(result->xla_output_shape))
+ ? xla::ShapeUtil::TupleElementCount(result->xla_output_shape)
+ : 1;
+ // Tensorflow expects a major-to-minor order of results.
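+  // (Filling minor_to_major in reverse with 0, 1, ... yields the permutation
+  // {rank-1, ..., 1, 0}, i.e. a row-major layout.)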
+ if (1 == num_non_constant_outputs) {
+ xla::Shape& s = result->xla_output_shape;
+ auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
+ minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
+ std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
+ } else {
+ for (xla::Shape& s : *result->xla_output_shape.mutable_tuple_shapes()) {
+ auto& minor_to_major = *s.mutable_layout()->mutable_minor_to_major();
+ minor_to_major.Resize(xla::ShapeUtil::Rank(s), 0);
+ std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0);
+ }
+ }
+
+ // Converts the output shapes to TensorShapes.
+ int computation_output = 0;
+ for (int i = 0; i < result->outputs.size(); ++i) {
+ if (!result->outputs[i].is_constant) {
+ CHECK_LT(computation_output, num_non_constant_outputs);
+ if (num_non_constant_outputs > 1) {
+ result->outputs[i].shape =
+ XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
+ result->xla_output_shape, computation_output));
+ } else {
+ result->outputs[i].shape =
+ XLAShapeToTensorShape(result->xla_output_shape);
+ }
+ ++computation_output;
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
new file mode 100644
index 0000000000..0b882d60a1
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -0,0 +1,203 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// The XlaCompiler class is responsible for compilation of a self-contained
+// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
+// It does a symbolic execution of the graph starting from specific input
+// shapes, using a JIT device to convert operators into XLA computations.
+//
+// It is typically invoked from an `_XlaLaunch` operator once the shapes
+// of all input parameters to the computation are known. This is
+// because the symbolic execution requires known shapes for all operations.
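+//
+// A typical use (an editor's sketch; `flr`, `fn`, `args`, `client`, and the
+// JIT device-type constant are assumed to be set up elsewhere and are not
+// part of this change) looks roughly like:
+//
+//   XlaCompiler::Options options;
+//   options.device_type = DeviceType(DEVICE_XLA_CPU_JIT);  // JIT device name
+//   options.client = client;                               // an xla::Client*
+//   XlaCompiler compiler(options);
+//
+//   XlaCompiler::CompilationResult result;
+//   TF_RETURN_IF_ERROR(compiler.CompileFunction(flr, fn, args, &result));
+//
+//   std::unique_ptr<xla::LocalExecutable> executable;
+//   TF_RETURN_IF_ERROR(compiler.BuildExecutable(result, &executable));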
+class XlaCompiler {
+ public:
+ // Describes how to derive the value of each _Arg node in the graph/function
+ // being compiled. Each argument must be either a parameter of the generated
+ // XLA computation (parameter >= 0), or a compile time constant
+ // (parameter < 0).
+ struct Argument {
+ // The type of the argument.
+ DataType type;
+
+ // The shape of the argument.
+ TensorShape shape;
+
+ // The parameter number of this argument to the XLA computation. < 0
+ // means this is a compile-time constant argument.
+ int parameter;
+
+ // The value of the argument, if it is a compile-time constant. Must be a
+ // host-memory tensor.
+ Tensor constant_value;
+
+ // The name of this argument, used for debugging.
+ string name;
+ };
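+
+  // For illustration (an editor's sketch; the names and values below are
+  // hypothetical), a runtime parameter and a compile-time constant Argument
+  // might be described as:
+  //   XlaCompiler::Argument a0;
+  //   a0.type = DT_FLOAT;
+  //   a0.shape = TensorShape({2, 3});
+  //   a0.parameter = 0;                       // XLA parameter 0
+  //
+  //   XlaCompiler::Argument a1;
+  //   a1.type = DT_INT32;
+  //   a1.shape = TensorShape({});
+  //   a1.parameter = -1;                      // compile-time constant
+  //   a1.constant_value = some_host_tensor;   // host-memory scalar Tensor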
+
+ struct OutputDescription {
+ // Shape of the output.
+ TensorShape shape;
+
+ // Constant output value, if known to be constant at JIT compilation time.
+ // 'Tensor' is in host memory.
+ bool is_constant = false;
+ Tensor constant_value;
+ };
+
+ struct CompilationResult {
+ // Vector of (Tensorflow input number, XLA shape) pairs that describe
+ // the arguments of the compiled XLA computation. (Because of constant
+ // inputs, the arguments to the XLA computation are a subset of the
+ // inputs passed to the JIT.)
+ std::vector<std::pair<int, xla::Shape>> xla_input_shapes;
+
+ // Does the computation require the local runtime context to be passed as
+ // the last argument?
+ bool requires_runtime_context = false;
+
+ // Output shape in XLA format. This is a tuple if and only if
+ // there are multiple non-constant outputs.
+ xla::Shape xla_output_shape;
+
+ // TensorFlow shapes of outputs, together with the values of any
+ // constant arguments. Vector indexed by Tensorflow _Retval number,
+ // containing both constant and non-constant arguments.
+ std::vector<OutputDescription> outputs;
+
+ // The XLA computation built from the tensorflow subgraph. May be null
+ // if the output consists solely of compile-time constants.
+ xla::Computation computation;
+ };
+
+ struct Options {
+ // Name of the compilation device to use.
+ DeviceType device_type = DeviceType("");
+
+ xla::Client* client = nullptr;
+
+ // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall()
+ // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed
+ // to the computation.
+ bool allow_cpu_custom_calls = false;
+
+ // If 'local_executable_has_hybrid_result', the top-level pointers of the
+ // result tuple of compiled programs are stored in host memory and the
+ // nested buffers in device memory, otherwise the whole result tuple is
+ // stored in device memory.
+ bool local_executable_has_hybrid_result = false;
+ };
+
+ explicit XlaCompiler(const Options& options);
+ ~XlaCompiler();
+
+ // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
+ // `args` describes the arguments to the function, each of which must either
+ // be a parameter to the XLA computation or a compile-time constant.
+ // Writes the compiled output to `result`.
+ //
+ // The generated XLA computation returns a tuple containing only the
+ // non-constant outputs as a function of the input arguments. Constant
+ // arguments are returned as host memory tensors in the output list and are
+ // not included in the XLA computation's outputs. The XLA computation is
+ // null if there are no data-dependent outputs.
+ Status CompileFunction(FunctionLibraryRuntime* flr,
+ const NameAttrList& fn_name_attrs,
+ const std::vector<Argument>& args,
+ CompilationResult* result);
+
+ // Compiles a tensorflow::Graph into an xla::Computation.
+ // Similar to CompileFunction, but takes a Graph as input rather than a
+ // function.
+ // If `use_tuple_arg` is true, the compilation takes all of its arguments as
+ // a single tuple.
+ Status CompileGraph(string const& name, std::unique_ptr<Graph> graph,
+ FunctionLibraryRuntime* flr,
+ const std::vector<Argument>& args, bool use_tuple_arg,
+ CompilationResult* result);
+
+ // Helper function that compiles a function to an XLA computation suitable
+ // for use as a subroutine in other Computations, e.g., the body of a
+ // While loop.
+ //
+ // The emitted Computation takes a single input parameter with
+ // input_shape. If this is a tuple then the tuple element shapes
+ // must match the types of the function's _Arg nodes. If input_shape
+ // is not a tuple then the function must have a single _Arg node
+ // with the same type as input_shape. The shapes of the _Arg values
+ // will be compiled to match input_shape.
+ //
+ // The emitted Computation also returns a single value. If output_shape is a
+ // tuple the tuple elements' types and shapes must match the compiled
+ // function's _Retval nodes. If output_shape is not a tuple the
+ // function must have a single _Retval node with the correct type
+ // (and shape after compilation).
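+  //
+  // For example (an editor's illustration; `compiler`, `flr`, `body_fn`, and
+  // `body` are assumed), a While-loop body over a single float scalar could
+  // be compiled with:
+  //   xla::Shape state = xla::ShapeUtil::MakeShape(xla::F32, {});
+  //   TF_RETURN_IF_ERROR(compiler.CompileSubComputation(
+  //       flr, body_fn, /*input_shape=*/state, /*output_shape=*/state, &body));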
+ Status CompileSubComputation(FunctionLibraryRuntime* flr,
+ const NameAttrList& fn_name_attrs,
+ const xla::Shape& input_shape,
+ const xla::Shape& output_shape,
+ xla::Computation* computation);
+
+  // Takes `*result`, which has already been compiled from a TensorFlow
+  // subgraph to an XLA computation, and generates an XLA LocalExecutable
+  // `executable`.
+ Status BuildExecutable(const CompilationResult& result,
+ std::unique_ptr<xla::LocalExecutable>* executable);
+
+ xla::Client* client() const { return client_; }
+ XlaCompilationDevice* device() const { return device_; }
+ const DeviceMgr* device_mgr() const { return &device_mgr_; }
+
+ private:
+  // Does the real work of CompileFunction() and CompileSubComputation().
+ Status CompileFunctionBody(FunctionLibraryRuntime* function_library,
+ const FunctionBody& function_body,
+ const string& name,
+ const std::vector<Argument>& args,
+ bool use_tuple_arg, CompilationResult* result);
+
+ xla::Client* client_; // Not owned.
+ const bool allow_cpu_custom_calls_;
+ const bool local_executable_has_hybrid_result_;
+
+ // Returns the next step sequence number.
+ int64 NextStepId();
+
+ mutex mu_;
+
+ // Internal sequence number for steps executed on the compilation device.
+ int64 next_step_id_ GUARDED_BY(mu_);
+
+ XlaCompilationDevice* device_; // Owned by device_mgr_
+ DeviceMgr device_mgr_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
new file mode 100644
index 0000000000..ad8fc3f205
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -0,0 +1,331 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+XlaExpression::XlaExpression() : has_constant_value_(false) {}
+
+void XlaExpression::set_handle(const xla::ComputationDataHandle& h) {
+ handle_ = h;
+}
+const xla::ComputationDataHandle& XlaExpression::handle() const {
+ return handle_;
+}
+
+void XlaExpression::set_constant_value(Tensor value) {
+ has_constant_value_ = true;
+ constant_value_ = std::move(value);
+}
+
+const char XlaContext::kXlaContextResourceName[] = "_xla_context";
+
+// Looks up the context associated with the current step. It is stored
+// in a resource container managed by the device.
+/* static */ XlaContext& XlaContext::Get(const OpKernelContext* ctx) {
+ // When an Op kernel wants to use an XLA JIT context, the
+ // per-step context is looked up in the resource manager. The
+ // JIT will prepopulate the JITContext.
+ XlaContext* context;
+ TF_CHECK_OK(ctx->resource_manager()->Lookup(
+ ctx->step_container()->name(), kXlaContextResourceName, &context));
+ // The resource manager handed us a fresh reference to 'context', but retains
+ // a reference itself so the context won't be freed. The resource manager will
+ // outlive the JIT compilation.
+ context->Unref();
+ return *context;
+}
+
+Status XlaContext::BuildArguments(std::vector<XlaCompiler::Argument> args,
+ bool use_tuple_arg) {
+ args_ = std::move(args);
+ use_tuple_arg_ = use_tuple_arg;
+
+  // Computes the number of parameters and verifies that they are sequential,
+  // starting from 0.
+ num_parameters_ = 0;
+ for (const XlaCompiler::Argument& arg : args_) {
+ if (arg.parameter < 0) continue;
+ if (num_parameters_ != arg.parameter) {
+ return errors::InvalidArgument(
+ "Parameter numbers to JIT compilation are not consecutive starting "
+ "from 0");
+ }
+ ++num_parameters_;
+
+ if (arg.shape.num_elements() == 0) {
+ return errors::InvalidArgument(
+ "Non-constant argument must have a non-zero number of elements.");
+ }
+ }
+ if (num_parameters_ == 0) return Status::OK();
+
+ parameters_.resize(num_parameters_);
+
+ std::vector<xla::Shape> parameter_shapes(num_parameters_);
+ for (int i = 0; i < args_.size(); ++i) {
+ const XlaCompiler::Argument& arg = args_[i];
+ if (arg.parameter < 0) continue;
+ // Computes the shapes of non-constant arguments.
+ xla::PrimitiveType type;
+ TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(arg.type, &type));
+ xla::ShapeUtil::PopulateShape(type, arg.shape.dim_sizes(),
+ &parameter_shapes[arg.parameter]);
+ }
+
+ if (use_tuple_arg_ && num_parameters_ > 0) {
+ xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(parameter_shapes);
+ xla::ComputationDataHandle tuple =
+ builder().Parameter(0, tuple_shape, "arg_tuple");
+ for (int i = 0; i < args_.size(); ++i) {
+ const XlaCompiler::Argument& arg = args_[i];
+ if (arg.parameter < 0) continue;
+ parameters_[arg.parameter] =
+ builder().GetTupleElement(tuple, arg.parameter);
+ }
+ } else {
+ for (int i = 0; i < args_.size(); ++i) {
+ const XlaCompiler::Argument& arg = args_[i];
+ if (arg.parameter < 0) continue;
+ parameters_[arg.parameter] =
+ builder().Parameter(arg.parameter, parameter_shapes[arg.parameter],
+ strings::StrCat("arg", i));
+ }
+ }
+ return Status::OK();
+}
+
+Status XlaContext::CollectResults(
+ xla::Computation* computation, bool* requires_runtime_context,
+ std::vector<ConstRetVal>* compile_time_constants,
+ int* num_nonconst_outputs) {
+ mutex_lock l(mu_);
+
+ bool return_singleton = (1 == retval_.size());
+
+ xla::ComputationDataHandle handle;
+ if (return_singleton) {
+ handle = retval_[0].second;
+
+    // TODO(b/31775371): to work around the bug, add a no-op computation that is
+ // guaranteed to be constructed after all of the formal parameters to the
+ // computation.
+ handle = builder().GetTupleElement(builder().Tuple({handle}), 0);
+
+ // Ensure that the retval is returned even if another computation
+ // was mistakenly placed on the ComputationBuilder.
+ TF_CHECK_OK(builder().SetReturnValue(handle));
+ } else {
+ if (!retval_.empty()) {
+ // There is at least one data-dependent expression: combine them
+ // into a Tuple in index order before compiling.
+ VLOG(1) << "Making the retval tuple.";
+ std::sort(retval_.begin(), retval_.end(),
+ [](const std::pair<int, xla::ComputationDataHandle>& a,
+ const std::pair<int, xla::ComputationDataHandle>& b) {
+ return a.first < b.first;
+ });
+ std::vector<xla::ComputationDataHandle> elems;
+ elems.reserve(retval_.size());
+ for (const std::pair<int, xla::ComputationDataHandle>& r : retval_) {
+ elems.push_back(r.second);
+ }
+ // Make a tuple from the vector of handles.
+ handle = builder().Tuple(elems);
+ }
+ }
+
+ if (handle.handle() > 0) {
+ // Build the full computation. The return value is the handle
+ // constructed above.
+ xla::StatusOr<xla::Computation> computation_status = builder().Build();
+ if (!computation_status.ok()) {
+ return computation_status.status();
+ }
+ *computation = computation_status.ConsumeValueOrDie();
+ }
+
+ // Make sure the compile time constants are in RetVal index order.
+ std::sort(compile_time_constant_.begin(), compile_time_constant_.end(),
+ [](const ConstRetVal& a, const ConstRetVal& b) {
+ return a.index < b.index;
+ });
+
+ // Fill in the result details and return.
+ *compile_time_constants = std::move(compile_time_constant_);
+ *requires_runtime_context = has_context_parameter_;
+ *num_nonconst_outputs = retval_.size();
+ return Status::OK();
+}
+
+XlaContext::XlaContext(xla::Client* client, const string& computation_name,
+ bool allow_cpu_custom_calls)
+ : xla_builder_(client, computation_name),
+ allow_cpu_custom_calls_(allow_cpu_custom_calls) {}
+
+const xla::ComputationDataHandle&
+XlaContext::GetOrCreateRuntimeContextParameter() {
+ mutex_lock lock(mu_);
+ CHECK(allow_cpu_custom_calls_);
+ CHECK(!use_tuple_arg_);
+ if (has_context_parameter_) return context_parameter_;
+ has_context_parameter_ = true;
+ context_parameter_ = xla_builder_.Parameter(
+ num_parameters_, xla::ShapeUtil::MakeOpaqueShape(), "tf_context");
+ return context_parameter_;
+}
+
+string XlaContext::DebugString() { return "XLA JIT context"; }
+
+// This is called by the Retval Op to associate a computed value
+// with a specific return value of the subgraph.
+void XlaContext::AddRetval(int retval_index,
+ const xla::ComputationDataHandle& handle) {
+ VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
+ // Add the return value to the list being built up. The executor
+ // is multi-threaded so this has to happen under the
+ // lock.
+ mutex_lock l(mu_);
+ retval_.emplace_back(retval_index, handle);
+}
+
+Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
+ const xla::Literal& literal) {
+ VLOG(1) << "Adding retval index " << retval_index
+ << " with non-data-dependent tensor to XLA computation";
+ ConstRetVal value;
+ value.index = retval_index;
+ TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value.value));
+ mutex_lock l(mu_);
+ compile_time_constant_.push_back(std::move(value));
+ return Status::OK();
+}
+
+/* static */ const XlaExpression* XlaContext::CastExpressionFromTensor(
+ const Tensor& tensor) {
+ const XlaExpression* expression =
+ reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
+ CHECK_NE(expression->handle().handle(), 0);
+ VLOG(1) << "Fetched T" << expression->handle().handle();
+ return expression;
+}
+
+/* static */ XlaExpression* XlaContext::CastExpressionFromUninitializedTensor(
+ Tensor* tensor) {
+ const XlaExpression* expression =
+ reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
+ CHECK_EQ(expression->handle().handle(), 0);
+ return const_cast<XlaExpression*>(expression);
+}
+
+/* static */ const XlaExpression* XlaContext::GetExpressionFromTensor(
+ const Tensor& tensor) {
+ return CastExpressionFromTensor(tensor);
+}
+
+/* static */ const xla::ComputationDataHandle&
+XlaContext::GetComputationFromTensor(const Tensor& tensor) {
+ return CastExpressionFromTensor(tensor)->handle();
+}
+
+xla::ComputationBuilder& XlaContext::builder() { return xla_builder_; }
+
+const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
+ return LookupOrCreate(type, &max_func_, [this, type] {
+ const string type_string = DataTypeString(type);
+ VLOG(1) << "Building Max() for " << type_string;
+ xla::ComputationBuilder b(builder().client(), "max<" + type_string + ">");
+ xla::PrimitiveType xla_type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
+ auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ b.Max(x, y);
+ return b.Build().ConsumeValueOrDie();
+ });
+}
+
+const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
+ return LookupOrCreate(type, &add_func_, [this, type] {
+ const string type_string = DataTypeString(type);
+ VLOG(1) << "Building Add() for " << type_string;
+ xla::ComputationBuilder b(builder().client(), "add<" + type_string + ">");
+ xla::PrimitiveType xla_type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
+ auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ b.Add(x, y);
+ return b.Build().ConsumeValueOrDie();
+ });
+}
+
+const xla::Computation* XlaContext::GetOrCreateSigmoid(const DataType type) {
+ return LookupOrCreate(type, &sigmoid_func_, [this, type] {
+ const string type_string = DataTypeString(type);
+ VLOG(1) << "Building Sigmoid() for " << type_string;
+ xla::ComputationBuilder b(builder().client(),
+ "sigmoid<" + type_string + ">");
+ xla::PrimitiveType xla_type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
+ auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto one = b.ConstantLiteral(xla::LiteralUtil::One(xla_type));
+ auto minus_one = b.Neg(one);
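+    // Computes 1 / (exp(-x) + 1), i.e. the logistic sigmoid.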
+ b.Div(one, b.Add(b.Exp(b.Mul(x, minus_one)), one));
+ return b.Build().ConsumeValueOrDie();
+ });
+}
+
+const xla::Computation* XlaContext::LookupOrCreate(
+ DataType type, ComputationMap* out,
+ const std::function<xla::Computation()>& create) {
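+  // Double-checked creation: probe the cache under the lock, build the
+  // computation without holding the lock, then re-check before publishing it.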
+ {
+ mutex_lock l(mu_);
+ const auto& entry = (*out)[type];
+ if (!entry.IsNull()) {
+ return &entry;
+ }
+ }
+ auto new_entry = create();
+ {
+ mutex_lock l(mu_);
+ // Somebody else might have made one concurrently.
+ auto& entry = (*out)[type];
+ if (entry.IsNull()) {
+ entry = std::move(new_entry);
+ }
+ return &entry;
+ }
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
new file mode 100644
index 0000000000..b0464025f7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines the contexts used to represent XLA JIT computations.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace tensorflow {
+
+// An XlaExpression wraps an XLA computation. Each Tensor sent
+// along an edge during XLA JIT compilation represents an
+// XlaExpression, and the shape of the Tensor matches the shape of
+// the subcomputation in the ComputationDataHandle. Each
+// expression is either a constant, an unbound parameter, or a
+// function of previously-compiled expressions.
+class XlaExpression {
+ public:
+ XlaExpression();
+
+ // handle() stores the XLA handle of the computation that the
+ // expression represents.
+ void set_handle(const xla::ComputationDataHandle& h);
+ const xla::ComputationDataHandle& handle() const;
+
+ void set_constant_value(Tensor value);
+ bool has_constant_value() const { return has_constant_value_; }
+ const Tensor& constant_value() const { return constant_value_; }
+
+ private:
+ friend class XlaContext;
+
+ // The XLA handle of the expression's computation.
+ xla::ComputationDataHandle handle_;
+
+ // If this expression is a constant with a known value, 'constant_value' is a
+ // host-memory Tensor containing the value. Used to avoid invoking XLA for
+ // expressions that are trivially constant.
+ bool has_constant_value_;
+ Tensor constant_value_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
+};
+
+// The XlaContext is the data structure accessible from
+// OpKernelContexts when evaluating a subgraph of Ops for JIT
+// compilation by XLA. When an Op is executed during JIT
+// compilation the input Tensors to the Op store handles to
+// subcomputations compiled by earlier Ops in the subgraph. The Op can
+// retrieve these subcomputations by calling either
+// GetExpressionFromTensor, which returns the XlaExpression holding
+// the subcomputation; or EvaluateAsConstant which returns an XLA
+// literal of the result of the subcomputation or an error status if
+// the subcomputation depends on unbound parameters. The Op may then
+// use the ComputationBuilder available from XlaContext::builder()
+// to compile one or more functions of the inputs into
+// ComputationDataHandles. The handles can be stored as new
+// expressions corresponding to the outputs of the Op by calling
+// CreateOutputTensorFromComputation or
+// CreateConstantOutputTensor. The *only* correct way to allocate an
+// output tensor is using one of the preceding two methods, since they
+// ensure there is a valid XlaExpression backing the output
+// tensor. No Op should ever call allocate_output or allocate_temp
+// directly on the OpKernelContext. It is permissible to pass a tensor
+// from an Op input to an output (e.g. call ctx->set_output with a
+// tensor passed as an input). As an example, the softmax Op produces
+// output from input as follows:
+//
+// XlaContext& tc = XlaContext::Get(context);
+// xla::ComputationBuilder& b = tc.builder();
+// xla::ComputationDataHandle logits =
+// tc.GetComputationFromTensor(logits_in));
+// ... The softmax computation uses the builder b to compute a
+// xla::ComputationDataHandle softmax holding the desired output.
+// ...
+// OP_REQUIRES_OK(context, tc.CreateOutputTensorFromComputation(
+// context, 0, logits_in.shape().dim_sizes(),
+// softmax));
+//
+class XlaContext : public ResourceBase {
+ public:
+ // If a retval can be evaluated at JIT time it is returned as a
+ // Literal in a ConstRetVal struct as part of the ComputationResult.
+ // TODO(misard) reconcile this with the duplicate data structure in
+ // the XlaCompilationCache class.
+ struct ConstRetVal {
+ // The index of the RetVal corresponding to this constant literal.
+ int index;
+ // If status is not OK, value's data is undefined.
+ Status status;
+ // The value of the RetVal evaluated at JIT compilation
+ // time. value.shape() always gives the correct shape of the
+ // RetVal. If !status.ok() then value's data is undefined, otherwise the
+ // Tensor buffer is allocated in CPU memory.
+ Tensor value;
+ };
+
+
+ // Virtual method defined by ResourceBase.
+ string DebugString() override;
+
+ // Retrieve the XlaContext corresponding to a step's JIT compilation.
+ static XlaContext& Get(const OpKernelContext* ctx);
+ static XlaContext& Get(const XlaOpKernelContext* ctx) {
+ return Get(ctx->op_kernel_context());
+ }
+
+ // Create a new XlaContext.
+ XlaContext(xla::Client* client, const string& computation_name,
+ bool allow_cpu_custom_calls);
+
+ // Builds XLA computations for each of the arguments.
+ // Should only be called once to initialize the arguments. Not thread-safe.
+ Status BuildArguments(std::vector<XlaCompiler::Argument> arguments,
+ bool use_tuple_arg) TF_MUST_USE_RESULT;
+
+ // Returns the results of the symbolic computation that have accumulated in
+ // the XlaContext. After CollectResults() is called, the context is left in
+ // an invalid state and must not be reused.
+ // Sets `requires_runtime_context` if the emitted computation requires a
+  // runtime context argument. `compile_time_constants` describes any
+  // non-data-dependent results of the computation. `num_nonconst_outputs` is
+  // set to the number of outputs of the `computation`.
+ Status CollectResults(xla::Computation* computation,
+ bool* requires_runtime_context,
+ std::vector<ConstRetVal>* compile_time_constants,
+ int* num_nonconst_outputs);
+
+ // This is called by the Retval Op to associate a computed value
+ // with a specific return value of the subgraph.
+ void AddRetval(int retval_index, const xla::ComputationDataHandle& handle);
+
+ // As for Retval, but for return values that are compile-time constants.
+ Status AddConstRetval(int retval_index, DataType dtype,
+ const xla::Literal& literal);
+
+ // Retrieves the ComputationDataHandle from an input Tensor to an Op. This
+ // computation was constructed by an Op that executed previously and
+ // created the output Tensor using CreateOutputTensorFromComputation
+ // or CreateConstantOutputTensor.
+ static const xla::ComputationDataHandle& GetComputationFromTensor(
+ const Tensor& tensor);
+
+ // Returns the ComputationBuilder that Ops use for compiling new
+ // expressions.
+ xla::ComputationBuilder& builder();
+
+ const std::vector<XlaCompiler::Argument>& args() const { return args_; }
+ xla::ComputationDataHandle parameter(int num) { return parameters_[num]; }
+
+ // Get the runtime context parameter, adding one if it does not already exist.
+ // Dies if not compiling a local executable.
+ const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter();
+
+ bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
+
+ // Get an XLA lambda to compute Max. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateMax(const DataType type);
+
+ // Get an XLA lambda to compute Add. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateAdd(const DataType type);
+
+ // Get an XLA lambda to compute Sigmoid. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateSigmoid(const DataType type);
+
+ // The name of the XlaContext resource during symbolic graph execution.
+ static const char kXlaContextResourceName[];
+
+ private:
+ friend class XlaOpKernelContext;
+
+ // This method is used to retrieve an expression that was allocated by
+ // a previous Op.
+ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
+
+ // This method is used to retrieve an uninitialized expression from a
+ // newly-allocated tensor.
+ static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor);
+
+ // Retrieves the expression from an input Tensor to an Op. This
+ // expression was constructed by an Op that executed previously and
+ // created the output Tensor using CreateOutputTensorFromComputation
+ // or CreateConstantOutputTensor.
+ static const XlaExpression* GetExpressionFromTensor(const Tensor& tensor);
+
+ mutable mutex mu_;
+
+ // The ComputationBuilder used to construct the subgraph's compiled
+ // representation.
+ xla::ComputationBuilder xla_builder_ GUARDED_BY(mu_);
+
+ // Number of XLA Parameters, not counting the context parameter, if any.
+ int num_parameters_;
+
+ // Arguments to the JIT compilation, both compile-time constant arguments and
+ // runtime parameters.
+ std::vector<XlaCompiler::Argument> args_;
+ bool use_tuple_arg_ = false;
+
+ // Runtime parameters to the XLA computation. Does not include
+ // compile-time constant arguments.
+ std::vector<xla::ComputationDataHandle> parameters_;
+
+ // Allow ops to emit CustomCall operations for CPU.
+ const bool allow_cpu_custom_calls_;
+
+ // When 'has_context_parameter_' is true, this is the computation handle
+  // for an additional final parameter to the computation, through which an
+  // XlaLocalRuntimeContext* is passed at runtime. Created on demand by
+ // GetOrCreateRuntimeContextParameter().
+ bool has_context_parameter_ GUARDED_BY(mu_) = false;
+ xla::ComputationDataHandle context_parameter_ GUARDED_BY(mu_);
+
+ // The data-dependent return values of the computation.
+ std::vector<std::pair<int, xla::ComputationDataHandle>> retval_
+ GUARDED_BY(mu_);
+
+ // The non-data-dependent return values of the computation.
+ std::vector<ConstRetVal> compile_time_constant_ GUARDED_BY(mu_);
+
+ // Cache of prebuilt computations indexed by their type.
+ using ComputationMap = std::map<DataType, xla::Computation>;
+
+  // Finds the value for the given type in the `out` map if it already
+  // exists, or makes a new value with the `create` function and keeps it in
+  // the map. The returned value is never nullptr and is owned by the map.
+ const xla::Computation* LookupOrCreate(
+ DataType type, ComputationMap* out,
+ const std::function<xla::Computation()>& create) LOCKS_EXCLUDED(mu_);
+
+ // Cached computation to compute Max of two elements, specialized by type.
+ ComputationMap max_func_ GUARDED_BY(mu_);
+
+ // Cached computation to compute Sum of two elements, specialized by type.
+ ComputationMap add_func_ GUARDED_BY(mu_);
+
+ // Cached computation to compute Sigmoid of an element, specialized by type.
+ ComputationMap sigmoid_func_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaContext);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_CONTEXT_H_
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
new file mode 100644
index 0000000000..efb0facf7b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines helper routines for XLA JIT compilation.
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::MinValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::Zero(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::One(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
+ xla::ComputationBuilder* b, DataType data_type, int64 value) {
+ xla::Literal literal;
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::U8:
+ literal = *xla::LiteralUtil::CreateR0<uint8>(value);
+ break;
+ case xla::U32:
+ literal = *xla::LiteralUtil::CreateR0<uint32>(value);
+ break;
+ case xla::U64:
+ literal = *xla::LiteralUtil::CreateR0<uint64>(value);
+ break;
+ case xla::S8:
+ literal = *xla::LiteralUtil::CreateR0<int8>(value);
+ break;
+ case xla::S32:
+ literal = *xla::LiteralUtil::CreateR0<int32>(value);
+ break;
+ case xla::S64:
+ literal = *xla::LiteralUtil::CreateR0<int64>(value);
+ break;
+ case xla::F32:
+ literal = *xla::LiteralUtil::CreateR0<float>(value);
+ break;
+ case xla::F64:
+ literal = *xla::LiteralUtil::CreateR0<double>(value);
+ break;
+ case xla::PRED:
+ LOG(FATAL) << "pred element type is not integral";
+ case xla::S16:
+ case xla::U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case xla::F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case xla::TUPLE:
+ LOG(FATAL) << "tuple element type is not integral";
+ case xla::OPAQUE:
+ LOG(FATAL) << "opaque element type is not integral";
+ default:
+ LOG(FATAL) << "unhandled element type " << type;
+ }
+ return b->ConstantLiteral(literal);
+}
+
+xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
+ DataType data_type,
+ double value) {
+ xla::Literal literal;
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::F32:
+ return b->ConstantR0<float>(static_cast<float>(value));
+ break;
+ case xla::F64:
+ return b->ConstantR0<double>(value);
+ break;
+ default:
+ LOG(FATAL) << "unhandled element type " << type;
+ }
+}
+
+/* static */ Status XlaHelpers::ReshapeLiteral(
+ const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
+ xla::Literal* output) {
+ if (xla::ShapeUtil::IsTuple(input.shape())) {
+ return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
+ }
+ xla::Shape shape =
+ xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
+ int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
+ int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
+ if (elements_before != elements_after) {
+ return errors::InvalidArgument(
+ "Shapes before and after ReshapeLiteral have different numbers of "
+ "elements.");
+ }
+
+ *output = input;
+ output->mutable_shape()->Swap(&shape);
+ return Status::OK();
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h
new file mode 100644
index 0000000000..353ed02edd
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_helpers.h
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines helper routines for the XLA device.
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
+
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+// Helper methods for building XLA computations.
+class XlaHelpers {
+ public:
+ // Returns a handle representing the minimum value of a scalar
+ // element of data_type.
+ static xla::ComputationDataHandle MinValue(xla::ComputationBuilder* b,
+ DataType data_type);
+
+ // Returns a handle representing the maximum value of a scalar
+ // element of data_type.
+ static xla::ComputationDataHandle MaxValue(xla::ComputationBuilder* b,
+ DataType data_type);
+
+ // Returns a handle representing the zero value of a scalar
+ // element of data_type.
+ static xla::ComputationDataHandle Zero(xla::ComputationBuilder* b,
+ DataType data_type);
+
+ // Returns a handle representing the one value of a scalar
+ // element of data_type.
+ static xla::ComputationDataHandle One(xla::ComputationBuilder* b,
+ DataType data_type);
+
+ // Returns a handle representing the given value of an integer scalar
+ // element of data_type.
+ // Note that unlike One and Zero, does not work on boolean types.
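+  // For example (an editor's illustration; `b` and `dtype` are assumed to be
+  // the ComputationBuilder* and the operand's DataType):
+  //   auto two = XlaHelpers::IntegerLiteral(b, dtype, 2);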
+ static xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* b,
+ DataType data_type,
+ int64 value);
+
+ // Returns a handle representing the given value of a floating-point scalar
+ // element of data_type.
+ static xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* b,
+ DataType data_type,
+ double value);
+
+ // Reshapes literal 'input' to have 'shape'. Both the original shape and
+ // 'shape' must contain the same number of elements.
+ static Status ReshapeLiteral(const xla::Literal& input,
+ gtl::ArraySlice<int64> shape,
+ xla::Literal* output);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h
new file mode 100644
index 0000000000..cd773d64ed
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's
+// actually used. E.g. some ahead-of-time compiled computations don't need a
+// thread pool.
+namespace Eigen {
+class ThreadPoolDevice;
+}
+
+namespace tensorflow {
+
+// An instance of this class is passed to each call from tensorflow into a
+// compiled XLA computation. See xla_launch_ops.cc.
+struct XlaLocalRuntimeContext {
+ public:
+ XlaLocalRuntimeContext() {}
+
+ // Kernels implemented using custom call ops set this if they encounter an
+ // error. The error is checked after the entire XLA computation is
+ // complete.
+ //
+ // error+error_msg are used instead of Status to reduce the binary size
+ // overhead for ahead-of-time compiled binaries.
+ bool error = false;
+ string error_msg;
+
+ // Kernels that need a thread pool can get it from here.
+ const Eigen::ThreadPoolDevice* thread_pool = nullptr;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_
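Since the struct is just a pair of error fields plus an optional thread pool, a brief hypothetical sketch of how a custom-call style kernel body might report a failure through it; the function name and arguments are illustrative and not part of this change:

#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"

// Hypothetical custom-call kernel body; the real kernels live elsewhere.
void ScaleOrFail(tensorflow::XlaLocalRuntimeContext* ctx, const float* in,
                 float* out, tensorflow::int64 n, float scale) {
  if (n < 0) {
    // The launch code only inspects these fields after the whole XLA
    // computation has run, so record the failure and return.
    ctx->error = true;
    ctx->error_msg = "ScaleOrFail: negative element count";
    return;
  }
  for (tensorflow::int64 i = 0; i < n; ++i) out[i] = in[i] * scale;
}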
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
new file mode 100644
index 0000000000..3883b907b4
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -0,0 +1,253 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+
+#include <numeric>
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+
+namespace tensorflow {
+
+XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
+ : context_(context) {}
+
+bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
+ return context_->ValidateInputsAreSameShape(op);
+}
+
+xla::ComputationBuilder* XlaOpKernelContext::builder() const {
+ return &XlaContext::Get(this).builder();
+}
+
+const xla::ComputationDataHandle& XlaOpKernelContext::Input(int index) {
+ return XlaContext::GetComputationFromTensor(context_->input(index));
+}
+
+TensorShape XlaOpKernelContext::InputShape(int index) {
+ return context_->input(index).shape();
+}
+
+Status XlaOpKernelContext::ConstantInput(int index,
+ xla::Literal* constant_literal) {
+ return ConstantInputReshaped(
+ index, context_->input(index).shape().dim_sizes(), constant_literal);
+}
+
+Status XlaOpKernelContext::ConstantInputReshaped(
+ int index, gtl::ArraySlice<int64> new_dims,
+ xla::Literal* constant_literal) {
+ const Tensor& tensor = context_->input(index);
+ TensorShape new_shape(new_dims);
+ if (tensor.NumElements() != new_shape.num_elements()) {
+ return errors::InvalidArgument(
+ context_->op_kernel().name(), " input ", index, " has shape ",
+ tensor.shape().DebugString(),
+ " but was asked to be reshaped to incompatible shape ",
+ new_shape.DebugString());
+ }
+ const XlaExpression* expression =
+ XlaContext::CastExpressionFromTensor(tensor);
+
+ // If the tensor has a known constant value, there is no need to invoke XLA.
+ if (expression->has_constant_value()) {
+ Tensor temp(tensor.dtype());
+ if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
+ // This should never happen. The constant should have a shape compatible
+ // with the enclosing Tensor.
+ return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
+ }
+ return HostTensorToLiteral(temp, constant_literal);
+ }
+
+ // Make sure we treat zero-element tensors as constant.
+ if (new_shape.num_elements() == 0) {
+ Tensor temp(tensor.dtype(), new_shape);
+ return HostTensorToLiteral(temp, constant_literal);
+ }
+
+ xla::ComputationDataHandle handle = expression->handle();
+ if (new_shape != tensor.shape()) {
+ // Reshape the handle to the desired shape.
+ handle = builder()->Reshape(handle, new_shape.dim_sizes());
+ }
+
+ // The XLA layout is specified minor to major, and TensorFlow's minor
+ // dimension is the last one.
+ std::vector<int64> layout_indices(new_shape.dims());
+ std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+ xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
+
+ // Ask the XLA compiler to evaluate the data handle to a literal.
+ xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed =
+ builder()->ComputeConstant(handle, &layout);
+ if (!computed.ok()) {
+ return errors::InvalidArgument(
+ "Error evaluating ", context_->op_kernel().name(), " input ", index,
+ ": ", computed.status().error_message());
+ }
+ // Fetch the literal from the compiler service.
+ xla::StatusOr<std::unique_ptr<xla::Literal>> constant =
+ builder()->client()->Transfer(*computed.ValueOrDie());
+ if (!constant.ok()) {
+ return errors::InvalidArgument(
+ "Error evaluating ", context_->op_kernel().name(), " input ", index,
+ ": ", constant.status().error_message());
+ }
+ constant_literal->Swap(constant.ValueOrDie().get());
+ return Status::OK();
+}
+
+// Converts an int32 or int64 1D literal to an int64 vector.
+static Status LiteralToInt64Vector(const xla::Literal& literal,
+ std::vector<int64>* out) {
+ if (xla::ShapeUtil::Rank(literal.shape()) != 1) {
+ return errors::InvalidArgument("value is not 1D");
+ }
+ int64 size = xla::ShapeUtil::ElementsIn(literal.shape());
+ if (literal.shape().element_type() == xla::S32) {
+ for (int64 i = 0; i < size; ++i) {
+ out->push_back(xla::LiteralUtil::Get<int32>(literal, {i}));
+ }
+ } else if (literal.shape().element_type() == xla::S64) {
+ for (int64 i = 0; i < size; ++i) {
+ out->push_back(xla::LiteralUtil::Get<int64>(literal, {i}));
+ }
+ } else {
+ return errors::InvalidArgument("value must be either int32 or int64");
+ }
+ return Status::OK();
+}
+
+Status XlaOpKernelContext::ConstantInputAsIntVector(int index,
+ std::vector<int64>* out) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ return LiteralToInt64Vector(literal, out);
+}
+
+// TODO(phawkins): validate that the dimensions form a valid shape, fail
+// gracefully if they do not.
+Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape) {
+ xla::Literal literal;
+ TF_RETURN_IF_ERROR(ConstantInput(index, &literal));
+ std::vector<int64> dims;
+ TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
+ *shape = TensorShape(dims);
+ return Status::OK();
+}
+
+Status XlaOpKernelContext::InputList(
+ StringPiece name, std::vector<xla::ComputationDataHandle>* handles,
+ std::vector<TensorShape>* shapes) {
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(context_->input_list(name, &inputs));
+ handles->clear();
+ shapes->clear();
+ for (const Tensor& input : inputs) {
+ handles->push_back(XlaContext::GetComputationFromTensor(input));
+ shapes->push_back(input.shape());
+ }
+ return Status::OK();
+}
+
+Status XlaOpKernelContext::ConstantInputList(
+ StringPiece name, std::vector<xla::Literal>* outputs) {
+ int start, stop;
+ TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
+ outputs->resize(stop - start);
+ for (int i = start; i < stop; ++i) {
+ TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i - start]));
+ }
+ return Status::OK();
+}
+
+void XlaOpKernelContext::SetOutput(int index,
+ const xla::ComputationDataHandle& handle) {
+ // Makes the host Tensor that will refer to the expression.
+ Tensor* output = nullptr;
+ auto shape = builder()->GetShape(handle);
+ if (!shape.ok()) {
+ SetStatus(shape.status());
+ return;
+ }
+
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ OP_REQUIRES_OK(
+ context_,
+ context_->allocate_output(
+ index, XLAShapeToTensorShape(*shape.ValueOrDie()), &output));
+
+ // The expression is stored in the tensor's data buffer. Fill in the
+ // fields now.
+ XlaExpression* expression =
+ XlaContext::CastExpressionFromUninitializedTensor(output);
+ expression->set_handle(handle);
+}
+
+void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
+ const TensorShape& shape = constant.shape();
+
+ xla::Literal literal;
+ OP_REQUIRES_OK(context_, HostTensorToLiteral(constant, &literal));
+ xla::ComputationDataHandle handle = builder()->ConstantLiteral(literal);
+
+ // Make the Tensor that will refer to the expression.
+ Tensor* output = nullptr;
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
+
+ // The expression is stored in the tensor's data buffer. Fill in the
+ // fields now.
+ XlaExpression* expression =
+ XlaContext::CastExpressionFromUninitializedTensor(output);
+ expression->set_handle(handle);
+ expression->set_constant_value(constant);
+}
+
+void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); }
+void XlaOpKernelContext::CtxFailureWithWarning(Status s) {
+ context_->CtxFailureWithWarning(s);
+}
+
+const xla::Computation* XlaOpKernelContext::GetOrCreateMax(
+ const DataType type) {
+ return XlaContext::Get(context_).GetOrCreateMax(type);
+}
+
+const xla::Computation* XlaOpKernelContext::GetOrCreateAdd(
+ const DataType type) {
+ return XlaContext::Get(context_).GetOrCreateAdd(type);
+}
+
+const xla::Computation* XlaOpKernelContext::GetOrCreateSigmoid(
+ const DataType type) {
+ return XlaContext::Get(context_).GetOrCreateSigmoid(type);
+}
+
+XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
+
+void XlaOpKernel::Compute(OpKernelContext* context) {
+ XlaOpKernelContext xla_context(context);
+ Compile(&xla_context);
+}
+
+} // namespace tensorflow
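One subtle step in ConstantInputReshaped above is the layout construction: std::iota over reverse iterators yields XLA's minor-to-major list with TensorFlow's last dimension first. A short standalone illustration, not part of the diff:

#include <numeric>
#include <vector>

// For a rank-3 TensorFlow shape, the reverse iota produces {2, 1, 0}: the
// last (minor-most) TensorFlow dimension appears first in the minor-to-major
// list handed to xla::LayoutUtil::MakeLayout.
void MinorToMajorSketch() {
  std::vector<int> layout_indices(3);
  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
  // layout_indices is now {2, 1, 0}.
}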
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
new file mode 100644
index 0000000000..0c614005be
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -0,0 +1,174 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+class XlaOpKernelContext;
+
+// Implementations of operators that generate XLA code should usually subclass
+// XlaOpKernel and implement the Compile() method. Unlike a regular OpKernel,
+// an XlaOpKernel produces and consumes symbolic values during compilation.
+//
+// See the comments in xla_context.h for more details.
+class XlaOpKernel : public OpKernel {
+ public:
+ explicit XlaOpKernel(OpKernelConstruction* construction);
+
+ // Subclasses should implement Compile(), much as standard OpKernels implement
+ // Compute().
+ virtual void Compile(XlaOpKernelContext* context) = 0;
+
+ private:
+ void Compute(OpKernelContext* context) final;
+};
+
+// The context passed to the Compile() method of XlaOpKernel. An
+// XlaOpKernelContext is a variant of the standard OpKernelContext class, tailored for
+// implementing operators that perform symbolic execution as part of the XLA
+// compiler. The key difference is that XlaOpKernelContext produces and consumes
+// data as XLA computations, rather than as standard Tensors. (Under the hood,
+// symbolic execution communicates using special Tensors, but that is an
+// implementation detail that this class hides.)
+class XlaOpKernelContext {
+ public:
+ explicit XlaOpKernelContext(OpKernelContext* context);
+
+ // Returns the XLA ComputationBuilder containing the output of compilation.
+ xla::ComputationBuilder* builder() const;
+
+ // Inputs
+
+ // Returns the number of inputs to the operator.
+ int num_inputs() const { return context_->num_inputs(); }
+
+ // Returns the type of input 'index'.
+ DataType input_type(int index) { return context_->input(index).dtype(); }
+
+ // Returns the shape of input 'index'.
+ TensorShape InputShape(int index);
+
+ // Returns input 'index' as a ComputationDataHandle. Unlike
+ // OpKernelContext::Input, this returns a symbolic value rather than a
+ // concrete Tensor.
+ const xla::ComputationDataHandle& Input(int index);
+
+ // Returns true if all inputs are the same shape, otherwise sets the
+ // status to a non-OK value and returns false.
+ // Usage: if (!context->ValidateInputsAreSameShape(this)) return;
+ bool ValidateInputsAreSameShape(OpKernel* op) TF_MUST_USE_RESULT;
+
+ // Returns the handles and shapes of the named list-valued immutable
+ // input, as defined in the OpDef. If the named input is not list-valued,
+ // returns a one-element list.
+ Status InputList(StringPiece name,
+ std::vector<xla::ComputationDataHandle>* handles,
+ std::vector<TensorShape>* shapes);
+
+ // Helper methods for constant inputs.
+
+ // Evaluates input 'index' and stores it in '*constant_literal'. If the
+ // expression cannot be evaluated, e.g., because it depends on unbound
+ // parameters, returns a non-OK status.
+ Status ConstantInput(int index, xla::Literal* constant_literal);
+
+ // Evaluates input 'index', reshapes it to 'new_shape' if new_shape !=
+ // InputShape(index), and stores it in '*constant_literal'. If the input
+ // cannot be evaluated, e.g., because it depends on unbound parameters,
+ // returns a non-OK status. If InputShape(index).num_elements() !=
+ // new_shape.num_elements(), returns an error status.
+ Status ConstantInputReshaped(int index, gtl::ArraySlice<int64> new_shape,
+ xla::Literal* constant_literal);
+
+ // Converts a constant 1D int32 or int64 tensor into a vector of int64s.
+ Status ConstantInputAsIntVector(int index, std::vector<int64>* out);
+
+ // Converts a constant 1D int32 or int64 tensor into a TensorShape.
+ Status ConstantInputAsShape(int index, TensorShape* shape);
+
+ // Returns the named list-valued immutable input as a list of constant
+ // literals, as defined in the OpDef. If the named input is not
+ // list-valued, returns a one-element list.
+ Status ConstantInputList(StringPiece name,
+ std::vector<xla::Literal>* literals);
+
+ // Outputs
+
+ int num_outputs() const { return context_->num_outputs(); }
+ DataType expected_output_dtype(int index) const {
+ return context_->expected_output_dtype(index);
+ }
+
+ // Sets output 'index' to the ComputationDataHandle 'handle'.
+ // All outputs should be set using SetOutput and SetConstantOutput, not
+ // via the underlying OpKernelContext.
+ void SetOutput(int index, const xla::ComputationDataHandle& handle);
+
+ // Sets output 'index' to the compile-time constant 'host_tensor', where
+ // 'host_tensor' is a tensor in host memory. Prefer SetConstantOutput over
+ // SetOutput whenever the output value is a compile-time constant.
+ void SetConstantOutput(int index, const Tensor& host_tensor);
+
+ // Status handling.
+ void SetStatus(const Status& status) { context_->SetStatus(status); }
+ Status status() { return context_->status(); }
+
+ // Helper routines for the OP_REQUIRES macros
+ void CtxFailure(Status s);
+ void CtxFailureWithWarning(Status s);
+
+ // If this kernel invocation is within a function execution,
+ // call_frame() returns the call frame for the function call.
+ FunctionCallFrame* call_frame() const { return context_->call_frame(); }
+
+ FunctionLibraryRuntime* function_library() const {
+ return context_->function_library();
+ }
+
+ const OpKernel& op_kernel() const { return context_->op_kernel(); }
+
+ // Returns the underlying OpKernelContext. Use rarely.
+ OpKernelContext* op_kernel_context() const { return context_; }
+
+ // TODO(phawkins): find a better home for these helpers.
+
+ // Get an XLA lambda to compute Max. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateMax(const DataType type);
+
+ // Get an XLA lambda to compute Add. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateAdd(const DataType type);
+
+ // Get an XLA lambda to compute Sigmoid. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateSigmoid(const DataType type);
+
+ private:
+ OpKernelContext* const context_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
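To make the subclassing pattern described above concrete, here is a minimal hypothetical kernel; the op definition and its registration are omitted, and the name NegExampleOp is illustrative only:

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

namespace tensorflow {

// Negates its single input symbolically: no Tensor data is produced at
// compile time, only operations emitted into the shared ComputationBuilder.
class NegExampleOp : public XlaOpKernel {
 public:
  explicit NegExampleOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    xla::ComputationDataHandle x = ctx->Input(0);
    ctx->SetOutput(0, ctx->builder()->Neg(x));
  }
};

}  // namespace tensorflow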
diff --git a/tensorflow/compiler/xla/.clang-format b/tensorflow/compiler/xla/.clang-format
new file mode 100644
index 0000000000..c2aa867556
--- /dev/null
+++ b/tensorflow/compiler/xla/.clang-format
@@ -0,0 +1,3 @@
+BasedOnStyle: Google
+Language: Cpp
+PointerBindsToType: true
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
new file mode 100644
index 0000000000..caa5141da4
--- /dev/null
+++ b/tensorflow/compiler/xla/BUILD
@@ -0,0 +1,561 @@
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow:internal"])
+
+package_group(
+ name = "friends",
+ packages = [
+ "//tensorflow/compiler/...",
+ ],
+)
+
+package_group(
+ name = "internal",
+ packages = [
+ "//tensorflow/compiler/xla/...",
+ ],
+)
+
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+xla_proto_library(
+ name = "xla_data_proto",
+ srcs = ["xla_data.proto"],
+ visibility = ["//visibility:public"],
+)
+
+xla_proto_library(
+ name = "xla_proto",
+ srcs = ["xla.proto"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":xla_data_proto",
+ "//tensorflow/compiler/xla/service:session_proto",
+ ],
+)
+
+cc_library(
+ name = "types",
+ hdrs = ["types.h"],
+ visibility = [":friends"],
+ deps = ["//tensorflow/core:lib"],
+)
+
+cc_library(
+ name = "service_interface",
+ srcs = [],
+ hdrs = ["service_interface.h"],
+ visibility = [":friends"],
+ deps = [
+ ":xla_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "status_macros",
+ srcs = ["status_macros.cc"],
+ hdrs = ["status_macros.h"],
+ visibility = [":friends"],
+ deps = [
+ ":statusor",
+ ":types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "status_macros_test",
+ size = "small",
+ srcs = ["status_macros_test.cc"],
+ deps = [
+ ":status_macros",
+ ":statusor",
+ ":test_helpers",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "status",
+ hdrs = ["status.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "statusor",
+ srcs = ["statusor.cc"],
+ hdrs = ["statusor.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":status",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_test(
+ name = "statusor_test",
+ size = "small",
+ srcs = ["statusor_test.cc"],
+ deps = [
+ ":statusor",
+ ":types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "util",
+ srcs = ["util.cc"],
+ hdrs = [
+ "map_util.h",
+ "ptr_util.h",
+ "util.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":status",
+ ":types",
+ ":xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:util_flags",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "protobuf_util",
+ srcs = ["protobuf_util.cc"],
+ hdrs = [
+ "protobuf_util.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "util_test",
+ srcs = ["util_test.cc"],
+ deps = [
+ ":types",
+ ":util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "shape_util",
+ srcs = [
+ "index_util.cc",
+ "layout_util.cc",
+ "primitive_util.cc",
+ "shape_util.cc",
+ ],
+ hdrs = [
+ "index_util.h",
+ "layout_util.h",
+ "primitive_util.h",
+ "shape_util.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":protobuf_util",
+ ":status_macros",
+ ":statusor",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ ],
+)
+
+cc_test(
+ name = "shape_util_test",
+ srcs = ["shape_util_test.cc"],
+ deps = [
+ ":shape_util",
+ ":test_helpers",
+ ":types",
+ ":util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "layout_util_test",
+ srcs = ["layout_util_test.cc"],
+ deps = [
+ ":shape_util",
+ ":test_helpers",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "index_util_test",
+ srcs = ["index_util_test.cc"],
+ deps = [
+ ":shape_util",
+ ":test_helpers",
+ ":xla_data_proto",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "literal_util",
+ srcs = ["literal_util.cc"],
+ hdrs = ["literal_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":array2d",
+ ":array3d",
+ ":array4d",
+ ":shape_util",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "literal_util_test",
+ srcs = ["literal_util_test.cc"],
+ deps = [
+ ":array3d",
+ ":array4d",
+ ":literal_util",
+ ":shape_util",
+ ":test_helpers",
+ ":types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "device_util",
+ hdrs = ["device_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "array2d",
+ srcs = ["array2d.cc"],
+ hdrs = ["array2d.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ ":util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "array2d_test",
+ srcs = ["array2d_test.cc"],
+ deps = [
+ ":array2d",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "array3d",
+ hdrs = ["array3d.h"],
+ visibility = [":friends"],
+ deps = [
+ ":types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "array3d_test",
+ srcs = ["array3d_test.cc"],
+ deps = [
+ ":array3d",
+ ":types",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "array4d",
+ hdrs = ["array4d.h"],
+ visibility = [":friends"],
+ deps = [
+ ":array2d",
+ ":types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "array4d_test",
+ srcs = ["array4d_test.cc"],
+ deps = [
+ ":array4d",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "executable_run_options",
+ srcs = ["executable_run_options.cc"],
+ hdrs = ["executable_run_options.h"],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "differential_set",
+ hdrs = ["differential_set.h"],
+ visibility = [":internal"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "differential_set_test",
+ srcs = ["differential_set_test.cc"],
+ deps = [
+ ":differential_set",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "packed_literal_reader",
+ srcs = ["packed_literal_reader.cc"],
+ hdrs = ["packed_literal_reader.h"],
+ visibility = [":internal"],
+ deps = [
+ ":literal_util",
+ ":shape_util",
+ ":status_macros",
+ ":statusor",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "test_helpers",
+ testonly = 1,
+ srcs = ["test_helpers.cc"],
+ hdrs = ["test_helpers.h"],
+ visibility = [":internal"],
+ deps = [
+ ":statusor",
+ ":types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "text_literal_reader",
+ srcs = ["text_literal_reader.cc"],
+ hdrs = ["text_literal_reader.h"],
+ visibility = [":internal"],
+ deps = [
+ ":literal_util",
+ ":shape_util",
+ ":status_macros",
+ ":statusor",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_test(
+ name = "text_literal_reader_test",
+ srcs = ["text_literal_reader_test.cc"],
+ deps = [
+ ":literal_util",
+ ":shape_util",
+ ":text_literal_reader",
+ ":types",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "text_literal_writer",
+ srcs = ["text_literal_writer.cc"],
+ hdrs = ["text_literal_writer.h"],
+ visibility = [":internal"],
+ deps = [
+ ":literal_util",
+ ":shape_util",
+ ":status_macros",
+ ":types",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "text_literal_writer_test",
+ srcs = ["text_literal_writer_test.cc"],
+ deps = [
+ ":literal_util",
+ ":test_helpers",
+ ":text_literal_writer",
+ ":types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "shape_tree",
+ hdrs = ["shape_tree.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":shape_util",
+ ":status_macros",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "shape_tree_test",
+ srcs = ["shape_tree_test.cc"],
+ deps = [
+ ":shape_tree",
+ ":shape_util",
+ ":xla_data_proto",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "shape_layout",
+ srcs = ["shape_layout.cc"],
+ hdrs = ["shape_layout.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":shape_util",
+ ":types",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "window_util",
+ srcs = ["window_util.cc"],
+ hdrs = ["window_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":types",
+ ":xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "reference_util",
+ srcs = ["reference_util.cc"],
+ hdrs = ["reference_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":array2d",
+ ":array3d",
+ ":array4d",
+ ":util",
+ ":window_util",
+ ":xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "reference_util_test",
+ srcs = ["reference_util_test.cc"],
+ deps = [
+ ":array2d",
+ ":array4d",
+ ":literal_util",
+ ":reference_util",
+ ":util",
+ ":xla_data_proto",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/README.md b/tensorflow/compiler/xla/README.md
new file mode 100644
index 0000000000..c93c39e180
--- /dev/null
+++ b/tensorflow/compiler/xla/README.md
@@ -0,0 +1 @@
+This is the home of XLA.
diff --git a/tensorflow/compiler/xla/array2d.cc b/tensorflow/compiler/xla/array2d.cc
new file mode 100644
index 0000000000..418587c1f7
--- /dev/null
+++ b/tensorflow/compiler/xla/array2d.cc
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+
+namespace xla {
+
+std::unique_ptr<Array2D<float>> MakeLinspaceArray2D(float from, float to,
+ int64 n1, int64 n2) {
+ auto array = MakeUnique<Array2D<float>>(n1, n2);
+ int64 count = n1 * n2;
+ float step = (count > 1) ? (to - from) / (count - 1) : 0.0f;
+ auto set = [&array, n1, n2](int64 index, float value) {
+ (*array)(index / n2, index % n2) = value;
+ };
+ for (int64 i = 0; i < count - 1; ++i) {
+ set(i, from + i * step);
+ }
+ set(count - 1, to);
+ return array;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h
new file mode 100644
index 0000000000..ceed573f1f
--- /dev/null
+++ b/tensorflow/compiler/xla/array2d.h
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_ARRAY2D_H_
+#define TENSORFLOW_COMPILER_XLA_ARRAY2D_H_
+
+#include <algorithm>
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <memory>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Simple 2D array structure.
+//
+// The data layout in major-to-minor order is: n1, n2.
+template <typename T>
+class Array2D {
+ public:
+ // Creates an empty array.
+ Array2D() : n1_(0), n2_(0) {}
+
+ // Creates an array of dimensions n1 x n2, uninitialized values.
+ Array2D(const int64 n1, const int64 n2) : n1_(n1), n2_(n2) {
+ values_.resize(n1 * n2);
+ }
+
+ // Creates an array of dimensions n1 x n2, initialized to value.
+ Array2D(const int64 n1, const int64 n2, const T value) : Array2D(n1, n2) {
+ Fill(value);
+ }
+
+ // Creates an array from the given nested initializer list. The outer
+ // initializer list is the first dimension; the inner is the second dimension.
+ // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
+ Array2D(std::initializer_list<std::initializer_list<T>> values)
+ : Array2D(values.size(), values.begin()->size()) {
+ int64 n1 = 0;
+ for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) {
+ int64 n2 = 0;
+ for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) {
+ (*this)(n1, n2) = *n2_it;
+ }
+ }
+ }
+
+ T& operator()(const int64 n1, const int64 n2) {
+ CHECK_LT(n1, n1_);
+ CHECK_LT(n2, n2_);
+ return values_[n1 * n2_ + n2];
+ }
+
+ const T& operator()(const int64 n1, const int64 n2) const {
+ CHECK_LT(n1, n1_);
+ CHECK_LT(n2, n2_);
+ return values_[n1 * n2_ + n2];
+ }
+
+ // Access to the array's dimensions. height() and width() provide the
+ // canonical interpretation of the array n1 x n2 having n1 rows of n2 columns
+ // each (height is number of rows; width is number of columns).
+ int64 n1() const { return n1_; }
+ int64 n2() const { return n2_; }
+ int64 height() const { return n1_; }
+ int64 width() const { return n2_; }
+ int64 num_elements() const { return values_.size(); }
+
+ // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
+ // to the underlying storage of the array (similarly to std::vector::data()).
+ T* data() const { return const_cast<Array2D*>(this)->values_.data(); }
+
+ // Fills the array with the given value.
+ void Fill(const T& value) {
+ std::fill(values_.begin(), values_.end(), value);
+ }
+
+ // Applies f to all cells in this array, in row-major order.
+ void Each(std::function<void(int64, int64, T*)> f) {
+ for (int64 i0 = 0; i0 < n1(); ++i0) {
+ for (int64 i1 = 0; i1 < n2(); ++i1) {
+ f(i0, i1, &(*this)(i0, i1));
+ }
+ }
+ }
+
+ // Fills the array with a pattern of values of the form:
+ //
+ // (rowno << log2ceil(width) | colno) + start_value
+ //
+ // This makes it easy to see distinct row/column values in the array.
+ void FillUnique(T start_value = 0) {
+ for (int64 i0 = 0; i0 < n1(); ++i0) {
+ for (int64 i1 = 0; i1 < n2(); ++i1) {
+ (*this)(i0, i1) =
+ ((i0 << tensorflow::Log2Ceiling64(n2())) | i1) + start_value;
+ }
+ }
+ }
+
+ // Fills the array with random normal variables with mean 'mean' and
+ // standard deviation 'value'.
+ void FillRandom(const T& value, const double mean = 0.0,
+ const int seed = 12345) {
+ std::mt19937 g(seed);
+ std::normal_distribution<double> distribution(mean,
+ static_cast<double>(value));
+ for (auto& v : values_) {
+ v = static_cast<T>(distribution(g));
+ }
+ }
+
+ // Returns a readable string representation of the array.
+ string ToString() const {
+ std::vector<string> pieces = {"["};
+ for (int64 row = 0; row < height(); ++row) {
+ pieces.push_back("[");
+ for (int64 col = 0; col < width(); ++col) {
+ pieces.push_back(tensorflow::strings::StrCat((*this)(row, col)));
+ pieces.push_back(", ");
+ }
+ pieces.pop_back();
+ pieces.push_back("]");
+ pieces.push_back(",\n ");
+ }
+ pieces.pop_back();
+ pieces.push_back("]");
+ return tensorflow::str_util::Join(pieces, "");
+ }
+
+ private:
+ int64 n1_;
+ int64 n2_;
+ std::vector<T> values_;
+};
+
+// Returns a linspace-populated Array2D in the range [from, to] (inclusive)
+// with dimensions n1 x n2.
+std::unique_ptr<Array2D<float>> MakeLinspaceArray2D(float from, float to,
+ int64 n1, int64 n2);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_ARRAY2D_H_
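As a quick check on the layout comment above (n1 major, n2 minor), a tiny illustration of the row-major relation that operator() implements; illustrative only, not part of the diff:

#include "tensorflow/compiler/xla/array2d.h"

void Array2DLayoutSketch() {
  xla::Array2D<int> a = {{1, 2, 3}, {4, 5, 6}};  // n1() == 2, n2() == 3
  // operator()(i0, i1) reads values_[i0 * n2 + i1], so the backing store is
  // row-major: data() holds {1, 2, 3, 4, 5, 6}.
  CHECK_EQ(a(1, 2), a.data()[1 * a.n2() + 2]);  // both are 6
}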
diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc
new file mode 100644
index 0000000000..ac107b1c0d
--- /dev/null
+++ b/tensorflow/compiler/xla/array2d_test.cc
@@ -0,0 +1,132 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array2d.h"
+
+#include <initializer_list>
+
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(Array2dTest, DefaultCtor) {
+ Array2D<int> empty;
+ EXPECT_EQ(empty.n1(), 0);
+ EXPECT_EQ(empty.n2(), 0);
+ EXPECT_EQ(empty.num_elements(), 0);
+}
+
+TEST(Array2dTest, UninitializedDimsCtor) {
+ Array2D<int> uninit(2, 3);
+ EXPECT_EQ(uninit.n1(), 2);
+ EXPECT_EQ(uninit.n2(), 3);
+ EXPECT_EQ(uninit.num_elements(), 6);
+}
+
+TEST(Array2dTest, FillCtor) {
+ Array2D<int> fullof7(2, 3, 7);
+
+ EXPECT_EQ(fullof7.n1(), 2);
+ EXPECT_EQ(fullof7.n2(), 3);
+
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ EXPECT_EQ(fullof7(n1, n2), 7);
+ }
+ }
+}
+
+TEST(Array2dTest, InitializerListCtor) {
+ Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
+
+ EXPECT_EQ(arr.n1(), 2);
+ EXPECT_EQ(arr.n2(), 3);
+
+ EXPECT_EQ(arr(0, 0), 1);
+ EXPECT_EQ(arr(0, 1), 2);
+ EXPECT_EQ(arr(0, 2), 3);
+ EXPECT_EQ(arr(1, 0), 4);
+ EXPECT_EQ(arr(1, 1), 5);
+ EXPECT_EQ(arr(1, 2), 6);
+}
+
+TEST(Array2dTest, Accessors) {
+ Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
+
+ EXPECT_EQ(arr.n1(), 2);
+ EXPECT_EQ(arr.n2(), 3);
+ EXPECT_EQ(arr.height(), 2);
+ EXPECT_EQ(arr.width(), 3);
+ EXPECT_EQ(arr.num_elements(), 6);
+}
+
+TEST(Array2dTest, IndexingReadWrite) {
+ Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
+
+ EXPECT_EQ(arr(1, 1), 5);
+ EXPECT_EQ(arr(1, 2), 6);
+ arr(1, 1) = 51;
+ arr(1, 2) = 61;
+ EXPECT_EQ(arr(1, 1), 51);
+ EXPECT_EQ(arr(1, 2), 61);
+}
+
+TEST(Array2dTest, Fill) {
+ Array2D<int> fullof7(2, 3, 7);
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ EXPECT_EQ(fullof7(n1, n2), 7);
+ }
+ }
+
+ fullof7.Fill(11);
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ EXPECT_EQ(fullof7(n1, n2), 11);
+ }
+ }
+}
+
+TEST(Array2dTest, DataPointer) {
+ Array2D<int> arr = {{1, 2, 3}, {4, 5, 6}};
+
+ EXPECT_EQ(arr.data()[0], 1);
+}
+
+TEST(Array2dTest, Linspace) {
+ auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
+
+ EXPECT_EQ(arr->n1(), 3);
+ EXPECT_EQ(arr->n2(), 2);
+
+ EXPECT_FLOAT_EQ((*arr)(0, 0), 1.0);
+ EXPECT_FLOAT_EQ((*arr)(0, 1), 1.5);
+ EXPECT_FLOAT_EQ((*arr)(1, 0), 2.0);
+ EXPECT_FLOAT_EQ((*arr)(1, 1), 2.5);
+ EXPECT_FLOAT_EQ((*arr)(2, 0), 3.0);
+ EXPECT_FLOAT_EQ((*arr)(2, 1), 3.5);
+}
+
+TEST(Array2dTest, Stringification) {
+ auto arr = MakeLinspaceArray2D(1.0, 3.5, 3, 2);
+ const string expected = R"([[1, 1.5],
+ [2, 2.5],
+ [3, 3.5]])";
+ EXPECT_EQ(expected, arr->ToString());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h
new file mode 100644
index 0000000000..46bc1a6392
--- /dev/null
+++ b/tensorflow/compiler/xla/array3d.h
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_ARRAY3D_H_
+#define TENSORFLOW_COMPILER_XLA_ARRAY3D_H_
+
+#include <algorithm>
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Simple 3D array structure.
+//
+// The data layout in major-to-minor order is: n1, n2, n3.
+template <typename T>
+class Array3D {
+ public:
+ // Creates an array of dimensions n1 x n2 x n3, uninitialized values.
+ Array3D(const int64 n1, const int64 n2, const int64 n3)
+ : n1_(n1), n2_(n2), n3_(n3) {
+ values_.resize(n1 * n2 * n3);
+ }
+
+ // Creates an array of dimensions n1 x n2 x n3, initialized to value.
+ Array3D(const int64 n1, const int64 n2, const int64 n3, const T value)
+ : Array3D(n1, n2, n3) {
+ Fill(value);
+ }
+
+ // Creates an array from the given nested initializer list. The outer
+ // initializer list is the first dimension, and so on.
+ //
+ // For example {{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
+ // {{9, 10}, {11, 12}, {13, 14}, {15, 16}},
+ // {{17, 18}, {19, 20}, {21, 22}, {23, 24}}}
+ // results in an array with n1=3, n2=4, n3=2.
+ Array3D(std::initializer_list<std::initializer_list<std::initializer_list<T>>>
+ values)
+ : Array3D(values.size(), values.begin()->size(),
+ values.begin()->begin()->size()) {
+ int64 n1 = 0;
+ for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) {
+ int64 n2 = 0;
+ for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) {
+ int64 n3 = 0;
+ for (auto n3_it = n2_it->begin(); n3_it != n2_it->end();
+ ++n3_it, ++n3) {
+ (*this)(n1, n2, n3) = *n3_it;
+ }
+ }
+ }
+ }
+
+ T& operator()(const int64 n1, const int64 n2, const int64 n3) {
+ CHECK_LT(n1, n1_);
+ CHECK_LT(n2, n2_);
+ CHECK_LT(n3, n3_);
+ return values_[n1 * n2_ * n3_ + n2 * n3_ + n3];
+ }
+
+ const T& operator()(const int64 n1, const int64 n2, const int64 n3) const {
+ CHECK_LT(n1, n1_);
+ CHECK_LT(n2, n2_);
+ CHECK_LT(n3, n3_);
+ return values_[n1 * n2_ * n3_ + n2 * n3_ + n3];
+ }
+
+ // Access to the array's dimensions.
+ int64 n1() const { return n1_; }
+ int64 n2() const { return n2_; }
+ int64 n3() const { return n3_; }
+ int64 num_elements() const { return values_.size(); }
+
+ // Fills the array with the given value.
+ void Fill(const T& value) {
+ std::fill(values_.begin(), values_.end(), value);
+ }
+
+ // Fills the array with sequentially increasing values.
+ void FillIota(const T& value) {
+ std::iota(values_.begin(), values_.end(), value);
+ }
+
+ // Fills the array with random normal values with mean 'mean' (default 0)
+ // and standard deviation 'value'.
+ void FillRandom(const T& value, const double mean = 0.0,
+ const int seed = 12345) {
+ std::mt19937 g(seed);
+ std::normal_distribution<double> distribution(mean,
+ static_cast<double>(value));
+ for (auto& v : values_) {
+ v = static_cast<T>(distribution(g));
+ }
+ }
+
+ private:
+ int64 n1_;
+ int64 n2_;
+ int64 n3_;
+ std::vector<T> values_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_ARRAY3D_H_
diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc
new file mode 100644
index 0000000000..fa4435dfc4
--- /dev/null
+++ b/tensorflow/compiler/xla/array3d_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array3d.h"
+
+#include <initializer_list>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(Array3dTest, UninitializedDimsCtor) {
+ Array3D<int> uninit(2, 3, 4);
+ EXPECT_EQ(uninit.n1(), 2);
+ EXPECT_EQ(uninit.n2(), 3);
+ EXPECT_EQ(uninit.n3(), 4);
+ EXPECT_EQ(uninit.num_elements(), 24);
+}
+
+TEST(Array3dTest, FillCtor) {
+ Array3D<int> fullof7(2, 3, 4, 7);
+
+ EXPECT_EQ(fullof7.n1(), 2);
+ EXPECT_EQ(fullof7.n2(), 3);
+ EXPECT_EQ(fullof7.n3(), 4);
+
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) {
+ EXPECT_EQ(fullof7(n1, n2, n3), 7);
+ }
+ }
+ }
+}
+
+TEST(Array3dTest, InitializerListCtor) {
+ Array3D<int> arr = {{{1, 2}, {3, 4}, {5, 6}, {7, 8}},
+ {{9, 10}, {11, 12}, {13, 14}, {15, 16}},
+ {{17, 18}, {19, 20}, {21, 22}, {23, 24}}};
+
+ EXPECT_EQ(arr.n1(), 3);
+ EXPECT_EQ(arr.n2(), 4);
+ EXPECT_EQ(arr.n3(), 2);
+ EXPECT_EQ(arr.num_elements(), 24);
+
+ EXPECT_EQ(arr(0, 0, 0), 1);
+ EXPECT_EQ(arr(0, 0, 1), 2);
+ EXPECT_EQ(arr(0, 1, 0), 3);
+ EXPECT_EQ(arr(0, 3, 1), 8);
+ EXPECT_EQ(arr(1, 0, 0), 9);
+ EXPECT_EQ(arr(1, 1, 1), 12);
+ EXPECT_EQ(arr(2, 0, 0), 17);
+ EXPECT_EQ(arr(2, 1, 1), 20);
+ EXPECT_EQ(arr(2, 2, 0), 21);
+ EXPECT_EQ(arr(2, 3, 1), 24);
+}
+
+TEST(Array3dTest, Fill) {
+ Array3D<int> fullof7(2, 3, 4, 7);
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) {
+ EXPECT_EQ(fullof7(n1, n2, n3), 7);
+ }
+ }
+ }
+
+ fullof7.Fill(11);
+ for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) {
+ for (int64 n2 = 0; n2 < fullof7.n2(); ++n2) {
+ for (int64 n3 = 0; n3 < fullof7.n3(); ++n3) {
+ EXPECT_EQ(fullof7(n1, n2, n3), 11);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h
new file mode 100644
index 0000000000..db51a57cf2
--- /dev/null
+++ b/tensorflow/compiler/xla/array4d.h
@@ -0,0 +1,272 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_ARRAY4D_H_
+#define TENSORFLOW_COMPILER_XLA_ARRAY4D_H_
+
+#include <algorithm>
+#include <functional>
+#include <initializer_list>
+#include <iterator>
+#include <numeric>
+#include <random>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Simple 4D array structure, similar in form to Array2D, for use primarily in
+// testing and describing to XLA APIs values in the 4D array structures used
+// in convolutions.
+//
+// The data layout is, in order from major to minor:
+//
+// First dimension: plane, batch, n1
+// Second dimension: depth, feature, z, n2
+// Third dimension: height, y, n3
+// Fourth dimension: width, x, n4
+//
+// These dimensions are referred to by various names, so that is why
+// more than one name is given above. See operator() for the exact
+// calculation of 1d indices from 4d indices.
+template <typename T>
+class Array4D {
+ public:
+ // Creates a 4D array, uninitialized values.
+ Array4D(int64 planes, int64 depth, int64 height, int64 width)
+ : planes_(planes), depth_(depth), height_(height), width_(width) {
+ values_.resize(planes * depth * height * width);
+ }
+
+ // Creates a 4D array, initialized to value.
+ Array4D(int64 planes, int64 depth, int64 height, int64 width, T value)
+ : Array4D(planes, depth, height, width) {
+ Fill(value);
+ }
+
+ // Creates a 4D array, filled with values.
+ //
+ // We need to set a default type for Container so that code like
+ // Array4D(1, 1, 1, 1, {1}) will work. The template cannot infer the
+ // initializer_list type in that case without this default.
+ template <typename Container = std::initializer_list<T>>
+ Array4D(int64 planes, int64 depth, int64 height, int64 width,
+ const Container& values)
+ : Array4D(planes, depth, height, width) {
+ SetValues(values);
+ }
+
+ // Construct an Array4D with the given nested initializer list.
+ Array4D(std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<T>>>>
+ values)
+ : Array4D(values.size(), values.begin()->size(),
+ values.begin()->begin()->size(),
+ values.begin()->begin()->begin()->size()) {
+ int64 plane = 0;
+ for (const auto values_in_plane : values) {
+ DCHECK_EQ(values_in_plane.size(), depth_);
+ int64 depth = 0;
+ for (const auto values_in_depth : values_in_plane) {
+ DCHECK_EQ(values_in_depth.size(), height_);
+ int64 height = 0;
+ for (const auto values_in_height : values_in_depth) {
+ DCHECK_EQ(values_in_height.size(), width_);
+ int64 width = 0;
+ for (const auto element_value : values_in_height) {
+ (*this)(plane, depth, height, width) = element_value;
+ ++width;
+ }
+ ++height;
+ }
+ ++depth;
+ }
+ ++plane;
+ }
+ }
+
+ T& operator()(int64 plane, int64 depth, int64 height, int64 width) {
+ CHECK_LT(plane, planes_);
+ CHECK_LT(depth, depth_);
+ CHECK_LT(height, height_);
+ CHECK_LT(width, width_);
+ return values_[plane * (depth_ * height_ * width_) +
+ depth * (height_ * width_) + height * (width_) + width];
+ }
+ const T& operator()(int64 plane, int64 depth, int64 height,
+ int64 width) const {
+ return const_cast<Array4D*>(this)->operator()(plane, depth, height, width);
+ }
+
+ int64 width() const { return width_; }
+ int64 height() const { return height_; }
+ int64 depth() const { return depth_; }
+ int64 planes() const { return planes_; }
+
+ // Numerically-named aliases for the various dimensions. This matches the
+ // dimension names used in array3d.
+ int64 n4() const { return width_; }
+ int64 n3() const { return height_; }
+ int64 n2() const { return depth_; }
+ int64 n1() const { return planes_; }
+ int64 num_elements() const { return values_.size(); }
+
+ // Sets all the values in the array to values.
+ template <typename Container = std::initializer_list<T>>
+ void SetValues(const Container& container) {
+ CHECK_EQ(std::distance(std::begin(container), std::end(container)),
+ num_elements());
+ values_.assign(std::begin(container), std::end(container));
+ }
+
+ // Fills the array with the given value.
+ void Fill(const T& value) {
+ std::fill(values_.begin(), values_.end(), value);
+ }
+
+ // Fills the array with sequentially increasing values, starting at 'value'.
+ void FillIota(const T& value) {
+ std::iota(values_.begin(), values_.end(), value);
+ }
+
+ // Fills the array with random normal variables with mean 'mean' and
+ // standard deviation 'value'.
+ void FillRandom(const T& value, const double mean = 0.0,
+ const int seed = 12345) {
+ std::mt19937 g(seed);
+ std::normal_distribution<double> distribution(mean,
+ static_cast<double>(value));
+ for (auto& v : values_) {
+ v = static_cast<T>(distribution(g));
+ }
+ }
+
+ // Fills values with the sequence i*multiplier for i=0,1,...
+ void FillWithMultiples(float multiplier) {
+ for (int64 i = 0; i < num_elements(); ++i) {
+ values_[i] = i * multiplier;
+ }
+ }
+
+ // Invokes a callback with the (indices, value_ptr) for each cell in the 4D
+ // array.
+ void Each(std::function<void(tensorflow::gtl::ArraySlice<int64>, T*)> f) {
+ for (int64 plane = 0; plane < planes(); ++plane) {
+ for (int64 depth = 0; depth < this->depth(); ++depth) {
+ for (int64 height = 0; height < this->height(); ++height) {
+ for (int64 width = 0; width < this->width(); ++width) {
+ auto& value = (*this)(plane, depth, height, width);
+ f({plane, depth, height, width}, &value);
+ }
+ }
+ }
+ }
+ }
+
+ // Fills all of the {p,z} with the array provided, which specifies {y,x}.
+ void FillWithYX(const Array2D<T>& value) {
+ CHECK_EQ(value.height(), height());
+ CHECK_EQ(value.width(), width());
+ for (int64 plane = 0; plane < planes(); ++plane) {
+ for (int64 depth = 0; depth < this->depth(); ++depth) {
+ for (int64 height = 0; height < this->height(); ++height) {
+ for (int64 width = 0; width < this->width(); ++width) {
+ (*this)(plane, depth, height, width) = value(height, width);
+ }
+ }
+ }
+ }
+ }
+
+ // Fills all of the {x,y} with the array provided, which specifies {p,z}.
+ void FillWithPZ(const Array2D<T>& value) {
+ CHECK_EQ(value.height(), planes());
+ CHECK_EQ(value.width(), depth());
+ for (int64 height = 0; height < this->height(); ++height) {
+ for (int64 width = 0; width < this->width(); ++width) {
+ for (int64 plane = 0; plane < planes(); ++plane) {
+ for (int64 depth = 0; depth < this->depth(); ++depth) {
+ (*this)(plane, depth, height, width) = value(plane, depth);
+ }
+ }
+ }
+ }
+ }
+
+ // Fills each of the minor-dim {y,x} matrices with a number identifying
+ // which {p,z} position it belongs to (computed as plane * depth() + depth).
+ void FillWithMinorDimNum() {
+ LOG(INFO) << "width: " << this->width();
+ LOG(INFO) << "height: " << this->height();
+ LOG(INFO) << "depth: " << this->depth();
+ LOG(INFO) << "planes: " << this->planes();
+ for (int64 height = 0; height < this->height(); ++height) {
+ for (int64 width = 0; width < this->width(); ++width) {
+ for (int64 plane = 0; plane < planes(); ++plane) {
+ for (int64 depth = 0; depth < this->depth(); ++depth) {
+ float this_val = plane * this->depth() + depth;
+ (*this)(plane, depth, height, width) = this_val;
+ }
+ }
+ }
+ }
+ }
+
+ // Returns a string representation of the 4D array suitable for debugging.
+ string ToString() const {
+ std::vector<string> pieces = {
+ tensorflow::strings::Printf("p=%lld,z=%lld,y=%lld,x=%lld {\n", planes(),
+ depth(), height(), width())};
+ for (int64 plane = 0; plane < planes_; ++plane) {
+ pieces.push_back(" {\n");
+ for (int64 depth = 0; depth < depth_; ++depth) {
+ pieces.push_back(" {\n");
+ for (int64 height = 0; height < height_; ++height) {
+ pieces.push_back(" {");
+ for (int64 width = 0; width < width_; ++width) {
+ pieces.push_back(tensorflow::strings::StrCat(
+ (*this)(plane, depth, height, width), ", "));
+ }
+ pieces.push_back("},\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back("}");
+ return tensorflow::str_util::Join(pieces, "");
+ }
+
+ private:
+ int64 planes_;
+ int64 depth_;
+ int64 height_;
+ int64 width_;
+ std::vector<T> values_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_ARRAY4D_H_
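And the same kind of sanity check for the 4D layout described above (planes major, width minor); the function is illustrative only and not part of the diff:

#include "tensorflow/compiler/xla/array4d.h"

void Array4DLayoutSketch() {
  xla::Array4D<float> a(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
  a.FillIota(0.0f);  // values are 0, 1, 2, ... in layout order
  // operator()(p, z, y, x) reads
  //   values_[p * depth*height*width + z * height*width + y * width + x],
  // so width (x) varies fastest and the last element holds 119.
  CHECK_EQ(a(1, 2, 3, 4), 1 * (3 * 4 * 5) + 2 * (4 * 5) + 3 * 5 + 4);
}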
diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc
new file mode 100644
index 0000000000..72ada467e5
--- /dev/null
+++ b/tensorflow/compiler/xla/array4d_test.cc
@@ -0,0 +1,180 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array4d.h"
+
+#include <initializer_list>
+#include <numeric>
+
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+// Given an Array4D and a 4-tuple index, computes the linear index into the
+// array idx represents.
+template <typename T>
+int64 Array4DLinearIndex(const Array4D<T>& arr,
+ tensorflow::gtl::ArraySlice<int64> idx) {
+ EXPECT_EQ(4, idx.size());
+ return (idx[3] + idx[2] * arr.n4() + idx[1] * arr.n3() * arr.n4() +
+ idx[0] * arr.n2() * arr.n3() * arr.n4());
+}
+
+TEST(Array4dTest, UninitializedDimsCtor) {
+ Array4D<int> empty(2, 3, 4, 5);
+ EXPECT_EQ(empty.n1(), 2);
+ EXPECT_EQ(empty.n2(), 3);
+ EXPECT_EQ(empty.n3(), 4);
+ EXPECT_EQ(empty.n4(), 5);
+ EXPECT_EQ(empty.num_elements(), 120);
+}
+
+TEST(Array4dTest, FillCtor) {
+ Array4D<int> fullof7(2, 3, 4, 5, 7);
+
+ EXPECT_EQ(fullof7.n1(), 2);
+ EXPECT_EQ(fullof7.n2(), 3);
+ EXPECT_EQ(fullof7.n3(), 4);
+ EXPECT_EQ(fullof7.n4(), 5);
+
+ fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ EXPECT_EQ(*cell, 7);
+ });
+}
+
+TEST(Array4dTest, ContainerCtor) {
+ // Fill an Array4D with a linear vector of [0..119] according to the default
+ // row-major ordering.
+ std::vector<int> filler(120);
+ std::iota(filler.begin(), filler.end(), 0);
+
+ Array4D<int> arr(2, 3, 4, 5, filler);
+
+ EXPECT_EQ(arr.n1(), 2);
+ EXPECT_EQ(arr.n2(), 3);
+ EXPECT_EQ(arr.n3(), 4);
+ EXPECT_EQ(arr.n4(), 5);
+
+ arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ EXPECT_EQ(*cell, Array4DLinearIndex(arr, idx));
+ });
+}
+
+TEST(Array4dTest, InitializerListCtor) {
+ Array4D<int> arr = {{{{1}, {2}}, {{3}, {4}}, {{5}, {6}}, {{7}, {8}}},
+ {{{9}, {10}}, {{11}, {12}}, {{13}, {14}}, {{15}, {16}}},
+ {{{17}, {18}}, {{19}, {20}}, {{21}, {22}}, {{23}, {24}}}};
+
+ EXPECT_EQ(arr.n1(), 3);
+ EXPECT_EQ(arr.n2(), 4);
+ EXPECT_EQ(arr.n3(), 2);
+ EXPECT_EQ(arr.n4(), 1);
+ EXPECT_EQ(arr.num_elements(), 24);
+
+ EXPECT_EQ(arr(0, 0, 0, 0), 1);
+ EXPECT_EQ(arr(0, 0, 1, 0), 2);
+ EXPECT_EQ(arr(0, 1, 0, 0), 3);
+ EXPECT_EQ(arr(0, 3, 1, 0), 8);
+ EXPECT_EQ(arr(1, 0, 0, 0), 9);
+ EXPECT_EQ(arr(1, 1, 1, 0), 12);
+ EXPECT_EQ(arr(2, 0, 0, 0), 17);
+ EXPECT_EQ(arr(2, 1, 1, 0), 20);
+ EXPECT_EQ(arr(2, 2, 0, 0), 21);
+ EXPECT_EQ(arr(2, 3, 1, 0), 24);
+}
+
+TEST(Array4dTest, Fill) {
+ Array4D<int> fullof7(2, 3, 4, 5, 7);
+ fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ EXPECT_EQ(*cell, 7);
+ });
+
+ fullof7.Fill(11);
+ fullof7.Each([](tensorflow::gtl::ArraySlice<int64> idx, int* cell) {
+ EXPECT_EQ(*cell, 11);
+ });
+}
+
+TEST(Array4dTest, FillWithMultiples) {
+ Array4D<float> arr(2, 3, 4, 5);
+ arr.FillWithMultiples(2.0f);
+
+ arr.Each([&arr](tensorflow::gtl::ArraySlice<int64> idx, float* cell) {
+ EXPECT_EQ(*cell, 2.0f * Array4DLinearIndex(arr, idx));
+ });
+}
+
+TEST(Array4dTest, FillRasterDimensionDepthOne) {
+ Array4D<float> array(1, 1, 128, 128);
+ Array2D<float> raster(128, 128);
+ for (int row = 0; row < 128; ++row) {
+ for (int col = 0; col < 128; ++col) {
+ raster(row, col) = row * 1000.0 + col;
+ }
+ }
+
+ array.FillWithYX(raster);
+
+ VLOG(1) << array.ToString();
+
+ EXPECT_FLOAT_EQ(raster(0, 0), array(0, 0, 0, 0));
+ EXPECT_FLOAT_EQ(raster(0, 1), array(0, 0, 0, 1));
+ EXPECT_FLOAT_EQ(raster(1, 0), array(0, 0, 1, 0));
+ EXPECT_FLOAT_EQ(raster(1, 1), array(0, 0, 1, 1));
+ EXPECT_FLOAT_EQ(raster(2, 0), array(0, 0, 2, 0));
+ EXPECT_FLOAT_EQ(raster(127, 127), array(0, 0, 127, 127));
+
+ EXPECT_FLOAT_EQ(0, array(0, 0, 0, 0));
+ EXPECT_FLOAT_EQ(1, array(0, 0, 0, 1));
+ EXPECT_FLOAT_EQ(2, array(0, 0, 0, 2));
+
+ EXPECT_FLOAT_EQ(1001, array(0, 0, 1, 1));
+ EXPECT_FLOAT_EQ(2001, array(0, 0, 2, 1));
+ EXPECT_FLOAT_EQ(127000, array(0, 0, 127, 0));
+ EXPECT_FLOAT_EQ(127127, array(0, 0, 127, 127));
+}
+
+TEST(Array4dTest, FillWithPzTestDepthOne) {
+ Array2D<float> matrix(3, 2);
+ std::initializer_list<std::initializer_list<float>> values = {
+ {-3.f, -0.1f}, {0.f, -0.1f}, {3.f, 0.2f},
+ };
+ int rowno = 0;
+ for (auto row : values) {
+ int colno = 0;
+ for (float f : row) {
+ matrix(rowno, colno) = f;
+ colno++;
+ }
+ rowno++;
+ }
+
+ Array4D<float> actual(3, 2, 1, 1);
+ actual.FillWithPZ(matrix);
+
+ EXPECT_FLOAT_EQ(-3, actual(0, 0, 0, 0));
+ EXPECT_FLOAT_EQ(-0.1, actual(0, 1, 0, 0));
+
+ EXPECT_FLOAT_EQ(0, actual(1, 0, 0, 0));
+ EXPECT_FLOAT_EQ(-0.1, actual(1, 1, 0, 0));
+
+ EXPECT_FLOAT_EQ(3, actual(2, 0, 0, 0));
+ EXPECT_FLOAT_EQ(0.2, actual(2, 1, 0, 0));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
new file mode 100644
index 0000000000..3e9dfe2a92
--- /dev/null
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -0,0 +1,175 @@
+# Description:
+# XLA client libraries.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [":friends"])
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "global_data",
+ srcs = ["global_data.cc"],
+ hdrs = ["global_data.h"],
+ deps = [
+ "//tensorflow/compiler/xla:service_interface",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "padding",
+ srcs = ["padding.cc"],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "padding_test",
+ srcs = ["padding_test.cc"],
+ deps = [
+ ":padding",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "client",
+ srcs = ["client.cc"],
+ hdrs = ["client.h"],
+ deps = [
+ ":computation",
+ ":global_data",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:service_interface",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "local_client",
+ srcs = ["local_client.cc"],
+ hdrs = ["local_client.h"],
+ deps = [
+ ":client",
+ ":computation",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:support",
+ ],
+)
+
+# This target is used to instantiate the XLA service in-process and create
+# a client for it.
+cc_library(
+ name = "client_library",
+ srcs = ["client_library.cc"],
+ hdrs = ["client_library.h"],
+ deps = [
+ ":local_client",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "computation",
+ srcs = ["computation.cc"],
+ hdrs = ["computation.h"],
+ deps = [
+ "//tensorflow/compiler/xla:service_interface",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "computation_builder",
+ srcs = ["computation_builder.cc"],
+ hdrs = ["computation_builder.h"],
+ deps = [
+ ":client",
+ ":computation",
+ ":global_data",
+ ":padding",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
new file mode 100644
index 0000000000..f70ab294c0
--- /dev/null
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -0,0 +1,479 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/client.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+Client::Client(ServiceInterface* stub) : stub_(stub) {}
+
+Client::~Client() = default;
+
+StatusOr<std::unique_ptr<Literal>> Client::Transfer(
+ const GlobalData& data, const Shape* shape_with_layout) {
+ TransferToClientRequest request;
+ *request.mutable_data() = data.handle();
+ if (shape_with_layout != nullptr) {
+ *request.mutable_shape_with_layout() = *shape_with_layout;
+ }
+ TransferToClientResponse response;
+
+ VLOG(1) << "making transfer request";
+ VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
+ Status s = stub_->TransferToClient(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";
+
+ if (!response.has_literal()) {
+ return FailedPrecondition(
+ "server provided response without a literal in "
+ "TransferToClient request");
+ }
+
+ return WrapUnique(response.release_literal());
+}
+
+Status Client::TransferInProcess(const GlobalData& data, void* destination) {
+ TransferToClientInProcessRequest request;
+ *request.mutable_data() = data.handle();
+ request.set_buffer(reinterpret_cast<uint64>(destination));
+ TransferToClientInProcessResponse response;
+
+ VLOG(1) << "making transfer in-process request";
+ VLOG(3) << "TransferToClientInProcessRequest: {" << request.DebugString()
+ << "}";
+ Status s = stub_->TransferToClientInProcess(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferToClientInProcessResponse: {" << response.DebugString()
+ << "}";
+ return Status::OK();
+}
+
+StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
+ const Literal& literal, const DeviceHandle* device_handle) {
+ TransferToServerRequest request;
+ *request.mutable_literal() = literal;
+ if (device_handle) {
+ *request.mutable_device_handle() = *device_handle;
+ }
+ TransferToServerResponse response;
+
+ VLOG(1) << "making transfer to server request";
+ VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
+ Status s = stub_->TransferToServer(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";
+
+ if (!response.has_data()) {
+ return FailedPrecondition(
+ "server provided response without a data handle in "
+ "TransferToServer request");
+ }
+
+ return MakeUnique<GlobalData>(stub_, response.data());
+}
+
+Status Client::TransferToInfeed(const Literal& literal, int64 replica_id,
+ const DeviceHandle* device_handle) {
+ TransferToInfeedRequest request;
+ *request.mutable_literal() = literal;
+ if (device_handle) {
+ *request.mutable_device_handle() = *device_handle;
+ }
+ request.set_replica_id(replica_id);
+ TransferToInfeedResponse response;
+
+ VLOG(1) << "making transfer to infeed request";
+ VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
+ Status s = stub_->TransferToInfeed(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
+ return Status::OK();
+}
+
+Status Client::ResetDevice() {
+ ResetDeviceRequest request;
+ ResetDeviceResponse response;
+
+ VLOG(1) << "making reset device request";
+ VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";
+ Status s = stub_->ResetDevice(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
+ return Status::OK();
+}
+
+StatusOr<std::unique_ptr<Literal>> Client::ExecuteAndTransfer(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout, ExecutionProfile* execution_profile,
+ uint64 seed) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
+ Execute(computation, arguments, shape_with_output_layout,
+ execution_profile, seed));
+ return Transfer(*data, shape_with_output_layout);
+}
+
+StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServerInProcess(
+ const Shape& shape, const void* buffer) {
+ TransferToServerInProcessRequest request;
+ request.set_buffer(reinterpret_cast<uint64>(buffer));
+ *request.mutable_shape() = shape;
+ TransferToServerInProcessResponse response;
+
+ VLOG(1) << "making transfer to server in-process request";
+ VLOG(3) << "TransferToServerInProcessRequest: {" << request.DebugString()
+ << "}";
+ Status s = stub_->TransferToServerInProcess(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ VLOG(3) << "TransferToServerInProcessResponse: {" << response.DebugString()
+ << "}";
+
+ if (!response.has_data()) {
+ return FailedPrecondition(
+ "server provided response without a data handle in "
+ "TransferToServerInProcess request");
+ }
+
+ return MakeUnique<GlobalData>(stub_, response.data());
+}
+
+StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
+ LoadComputationSnapshotRequest request;
+ *request.mutable_module() = module;
+ LoadComputationSnapshotResponse response;
+
+ Status s = stub_->LoadComputationSnapshot(&request, &response);
+ if (!s.ok()) {
+ return s;
+ }
+
+ VLOG(1) << "load snapshot response: " << response.ShortDebugString();
+ return Computation(stub_, response.computation());
+}
+
+StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout, ExecutionProfile* execution_profile,
+ uint64 seed) {
+ ExecuteRequest request;
+ *request.mutable_computation() = computation.handle();
+ request.set_seed(seed);
+ for (GlobalData* argument : arguments) {
+ *request.add_arguments() = argument->handle();
+ }
+ if (shape_with_output_layout != nullptr) {
+ *request.mutable_shape_with_output_layout() = *shape_with_output_layout;
+ }
+
+ ExecuteResponse response;
+ VLOG(1) << "making execute request: " << request.ShortDebugString();
+ Status s = stub_->Execute(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (execution_profile != nullptr) {
+ *execution_profile = response.profile();
+ if (VLOG_IS_ON(1)) {
+ TF_ASSIGN_OR_RETURN(
+ auto execution_stats,
+ ExecutionStatsAsString(computation, response.profile()));
+ VLOG(1) << execution_stats;
+ }
+ }
+
+ return MakeUnique<GlobalData>(stub_, response.output());
+}
+
+StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
+ tensorflow::gtl::ArraySlice<ComputationInstance> computations) {
+ ExecuteParallelRequest request;
+
+ for (const ComputationInstance& computation : computations) {
+ ExecuteRequest single_request;
+ *single_request.mutable_computation() = computation.computation.handle();
+ for (GlobalData* argument : computation.arguments) {
+ *single_request.add_arguments() = argument->handle();
+ }
+ if (computation.device_handle != nullptr) {
+ *single_request.mutable_device_handle() = *computation.device_handle;
+ }
+ if (computation.shape_with_output_layout != nullptr) {
+ *single_request.mutable_shape_with_output_layout() =
+ *computation.shape_with_output_layout;
+ }
+ single_request.set_seed(computation.seed);
+ *request.add_requests() = single_request;
+ }
+
+ ExecuteParallelResponse response;
+ VLOG(1) << "making execute-parallel request: " << request.ShortDebugString();
+ tensorflow::Status s = stub_->ExecuteParallel(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<std::unique_ptr<GlobalData>> outputs;
+ for (int64 i = 0; i < computations.size(); ++i) {
+ outputs.push_back(
+ MakeUnique<GlobalData>(stub_, response.responses(i).output()));
+ if (computations[i].execution_profile != nullptr) {
+ *computations[i].execution_profile = response.responses(i).profile();
+ }
+ }
+
+ return std::move(outputs);
+}
+
+StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
+ int64 device_count) {
+ if (device_count < 1) {
+ return InvalidArgument("device_count must be greater than 0");
+ }
+ GetDeviceHandlesRequest request;
+ request.set_device_count(device_count);
+
+ GetDeviceHandlesResponse response;
+ VLOG(1) << "making get device request: " << request.ShortDebugString();
+ tensorflow::Status s = stub_->GetDeviceHandles(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<DeviceHandle> device_handles;
+ for (const DeviceHandle& device_handle : response.device_handles()) {
+ device_handles.push_back(device_handle);
+ }
+
+ return device_handles;
+}
+
+StatusOr<ExecutionHandle> Client::ExecuteAsync(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout, uint64 seed) {
+ ExecuteAsyncRequest request;
+ *request.mutable_computation() = computation.handle();
+ request.set_seed(seed);
+ for (GlobalData* argument : arguments) {
+ *request.add_arguments() = argument->handle();
+ }
+ if (shape_with_output_layout != nullptr) {
+ *request.mutable_shape_with_output_layout() = *shape_with_output_layout;
+ }
+
+ ExecuteAsyncResponse response;
+ VLOG(1) << "making execute async request: " << request.ShortDebugString();
+ Status s = stub_->ExecuteAsync(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ return response.execution();
+}
+
+StatusOr<std::unique_ptr<GlobalData>> Client::WaitForExecution(
+ const Computation& computation, const ExecutionHandle& execution,
+ ExecutionProfile* execution_profile) {
+ WaitForExecutionRequest request;
+ *request.mutable_execution() = execution;
+
+ WaitForExecutionResponse response;
+ VLOG(1) << "making wait-for-execute request: " << request.ShortDebugString();
+ Status s = stub_->WaitForExecution(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ if (execution_profile != nullptr) {
+ *execution_profile = response.profile();
+ if (VLOG_IS_ON(1)) {
+ TF_ASSIGN_OR_RETURN(
+ auto execution_stats,
+ ExecutionStatsAsString(computation, response.profile()));
+ VLOG(1) << execution_stats;
+ }
+ }
+
+ return MakeUnique<GlobalData>(stub_, response.output());
+}
+
+Status Client::Unregister(const GlobalData& data) {
+ UnregisterRequest request;
+ *request.mutable_data() = data.handle();
+ UnregisterResponse response;
+
+ VLOG(1) << "making unregister request";
+ Status s = stub_->Unregister(&request, &response);
+ VLOG(1) << "done with request";
+
+ return s;
+}
+
+StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
+ const GlobalData& data) {
+ DeconstructTupleRequest request;
+ *request.mutable_tuple_handle() = data.handle();
+ DeconstructTupleResponse response;
+
+ VLOG(1) << "making DestructTuple request";
+ Status s = stub_->DeconstructTuple(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ std::vector<std::unique_ptr<GlobalData>> handles;
+ for (auto& handle : response.element_handles()) {
+ handles.push_back(MakeUnique<GlobalData>(stub_, handle));
+ }
+ return std::move(handles);
+}
+
+StatusOr<ComputationStats> Client::GetComputationStats(
+ const Computation& computation) const {
+ ComputationStatsRequest request;
+ *request.mutable_computation() = computation.handle();
+ ComputationStatsResponse response;
+
+ VLOG(1) << "making computation stats request";
+ Status s = stub_->GetComputationStats(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+ CHECK(response.has_stats());
+ return response.stats();
+}
+
+StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
+ const Computation& computation) {
+ GetComputationShapeRequest request;
+ *request.mutable_computation() = computation.handle();
+ GetComputationShapeResponse response;
+
+ VLOG(1) << "making get-computation-shape request";
+ Status s = stub_->GetComputationShape(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ return WrapUnique(response.release_program_shape());
+}
+
+StatusOr<Shape> Client::GetShape(const GlobalData& data) {
+ GetShapeRequest request;
+ *request.mutable_data() = data.handle();
+ GetShapeResponse response;
+
+ VLOG(1) << "making get shape request";
+ Status s = stub_->GetShape(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ return response.shape();
+}
+
+StatusOr<string> Client::ExecutionStatsAsString(
+ const Computation& computation, const ExecutionProfile& profile) {
+ TF_ASSIGN_OR_RETURN(auto computation_stats, GetComputationStats(computation));
+ int64 total_flops =
+ computation_stats.flop_count() + computation_stats.transcendental_count();
+ if (profile.compute_time_ns() > 0) {
+ int64 nanoseconds = profile.compute_time_ns();
+ int64 cycle_count = profile.compute_cycle_count();
+    // Flop count divided by nanoseconds gives Gflop/s; compute in floating
+    // point to avoid truncating integer division.
+    double gflops = static_cast<double>(total_flops) / nanoseconds;
+ return tensorflow::strings::StrCat(
+ "[Execution Statistics] flop count: ", computation_stats.flop_count(),
+ ", transcendental count: ", computation_stats.transcendental_count(),
+ ", compute execution time: ", nanoseconds, " nsec",
+ ", compute cycles: ", cycle_count, ", performance: ", gflops,
+ "gflop/s");
+ }
+ return string("[Execution Statistics] not available.");
+}
+
+StatusOr<ChannelHandle> Client::CreateChannelHandle() {
+ CreateChannelHandleRequest request;
+ CreateChannelHandleResponse response;
+
+ VLOG(1) << "making create channel handle request";
+ Status s = stub_->CreateChannelHandle(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ return s;
+ }
+
+ return response.channel();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
new file mode 100644
index 0000000000..b5beeb36b1
--- /dev/null
+++ b/tensorflow/compiler/xla/client/client.h
@@ -0,0 +1,202 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// XLA service's client object -- wraps the service with convenience and
+// lifetime-oriented methods.
+class Client {
+ public:
+ explicit Client(ServiceInterface* stub);
+ virtual ~Client();
+
+ // Executes the computation with the given arguments and returns the global
+ // data that was produced from the execution.
+ // * If shape_with_output_layout is not nullptr this points to a shape with a
+ // layout to use as a hint when storing the output of the computation.
+ // Subsequent transfers of this output array to the client may be faster
+ // when using this layout.
+ // * If execution_profile is not nullptr then the pointed-to ExecutionProfile
+ // will be filled with profile data from the execution.
+ // * If seed is set then that seed is used for the graph execution.
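+  //
+  // Example (illustrative sketch; assumes `client` is a Client* connected to
+  // a service, `computation` was built elsewhere, and `arg` is a
+  // std::unique_ptr<GlobalData> previously obtained from TransferToServer):
+  //
+  //   ExecutionProfile profile;
+  //   std::vector<GlobalData*> args = {arg.get()};
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> result,
+  //                       client->Execute(computation, args,
+  //                                       /*shape_with_output_layout=*/nullptr,
+  //                                       &profile));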
+ StatusOr<std::unique_ptr<GlobalData>> Execute(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr,
+ ExecutionProfile* execution_profile = nullptr, uint64 seed = 0);
+
+ // A struct to represent a computation instance to be executed.
+ // * If device_handle is not nullptr, the computation is executed on a device
+ // associated with the handle. Otherwise, a device is chosen by the service.
+  // * If shape_with_output_layout is not nullptr, the given shape and its
+  //   layout are used as a hint when storing the output of the computation.
+ // * If execution_profile is not nullptr, the pointed-to ExecutionProfile will
+ // be filled with profile data from the execution of the computation.
+ // * seed is for the random number generator used in the computation.
+ struct ComputationInstance {
+ const Computation& computation;
+ std::vector<GlobalData*> arguments;
+ const DeviceHandle* device_handle;
+ const Shape* shape_with_output_layout;
+ ExecutionProfile* execution_profile;
+ uint64 seed;
+ };
+
+  // Executes a list of ComputationInstances and returns the global data
+  // produced by each computation.
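+  //
+  // Example (illustrative sketch; assumes `c0` and `c1` are Computations and
+  // `arg0`/`arg1` are GlobalData* arguments for them):
+  //
+  //   std::vector<Client::ComputationInstance> instances;
+  //   instances.push_back({c0, {arg0}, /*device_handle=*/nullptr,
+  //                        /*shape_with_output_layout=*/nullptr,
+  //                        /*execution_profile=*/nullptr, /*seed=*/0});
+  //   instances.push_back({c1, {arg1}, nullptr, nullptr, nullptr, 0});
+  //   TF_ASSIGN_OR_RETURN(auto outputs, client->ExecuteParallel(instances));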
+ StatusOr<std::vector<std::unique_ptr<GlobalData>>> ExecuteParallel(
+ tensorflow::gtl::ArraySlice<ComputationInstance> computations);
+
+ // Requests device_count device handles available on the target. The returned
+ // device handles are used to specify the devices to execute the computations
+ // (see ExecuteParallel) or to transfer data (see TransferToServer or
+ // TransferToInfeed).
+ StatusOr<std::vector<DeviceHandle>> GetDeviceHandles(int64 device_count);
+
+  // Executes the given computation as in Execute() above, but launches the
+  // computation asynchronously and returns before the execution is complete.
+  // Returns an ExecutionHandle that represents the launched execution; pass it
+  // to WaitForExecution() to wait for the execution's completion.
+ StatusOr<ExecutionHandle> ExecuteAsync(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr, uint64 seed = 0);
+
+ // Waits until the given asynchronously launched execution of the computation
+ // is complete and returns the execution result. Once this is called, the
+ // given execution handle is no longer valid. If execution_profile is not
+ // nullptr then the pointed-to ExecutionProfile will be filled with profile
+ // data from the execution.
+ StatusOr<std::unique_ptr<GlobalData>> WaitForExecution(
+ const Computation& computation, const ExecutionHandle& execution,
+ ExecutionProfile* execution_profile = nullptr);
+
+  // Transfers the contents of the given GlobalData from the service to the
+  // client and returns it as a Literal. Use sparingly to avoid transfer
+  // overheads.
+ //
+ // If shape_with_layout is not nullptr, it points to a shape whose layout will
+ // be the layout of the returned literal.
+ StatusOr<std::unique_ptr<Literal>> Transfer(
+ const GlobalData& data, const Shape* shape_with_layout = nullptr);
+
+ // Transfer the given literal to the server. This allocates memory on the
+ // device and copies the literal's contents over. Returns a global data handle
+ // that can be used to refer to this value from the client.
+ //
+ // If device_handle is not nullptr, data is transferred to the associated
+ // device (and its replicas if replication is enabled). Otherwise, data is
+ // transferred to the default device (and its replicas).
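+  //
+  // Example (illustrative sketch; assumes `literal` is an xla::Literal already
+  // populated on the client side):
+  //
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
+  //                       client->TransferToServer(literal));
+  //   // ... execute computations that consume `data` ...
+  //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> round_tripped,
+  //                       client->Transfer(*data));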
+ StatusOr<std::unique_ptr<GlobalData>> TransferToServer(
+ const Literal& literal, const DeviceHandle* device_handle = nullptr);
+
+ // Transfer the given literal to the Infeed interface of the device.
+ //
+  // device_handle and replica_id together specify a particular device: the
+  // device assigned to the given replica_id among the replicas associated with
+  // the given device handle.
+ Status TransferToInfeed(const Literal& literal, int64 replica_id = 0,
+ const DeviceHandle* device_handle = nullptr);
+
+ // Resets the device, clearing all existing state on the device.
+ Status ResetDevice();
+
+ // Executes the computation with the given arguments and transfers the result
+ // to the client as a literal. Parameters are defined the same as for
+ // Execute() and Transfer().
+ StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr,
+ ExecutionProfile* execution_profile = nullptr, uint64 seed = 0);
+
+ // Unregister the memory for the given GlobalData on the device.
+ Status Unregister(const GlobalData& data);
+
+ // Returns a vector of global data handles that point to the tuple elements.
+ StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
+ const GlobalData& computation);
+
+ // Retrieves the statistics of the given computation.
+ StatusOr<ComputationStats> GetComputationStats(
+ const Computation& computation) const;
+
+ // Returns the Shape of the given array specified by 'data'. The shape
+ // includes the Layout of the array as it is stored on the service. The layout
+ // information is useful for calling TransferInProcess.
+ StatusOr<Shape> GetShape(const GlobalData& data);
+
+ // As above, but returns the shape of the provided computation (parameter
+ // types/names and return type).
+ StatusOr<std::unique_ptr<ProgramShape>> GetComputationShape(
+ const Computation& computation);
+
+ // Creates a channel handle that can be used to transfer data between
+ // two computations via a pair of Send and Recv instructions.
+ StatusOr<ChannelHandle> CreateChannelHandle();
+
+ // If the service is running in the same process as the client then the
+ // following "InProcess" transfer methods may be used. These methods enable
+ // more efficient transfer of arrays to and from the service.
+
+ // Transfer array from the service into the given buffer. The buffer must be
+ // large enough to hold the array. The array is copied verbatim (memcpy) from
+ // the service. The method GetShape should be called ahead of time
+ // to get the shape and layout of the array as it is stored in the
+ // service. The shape and layout can be used to determine how large the buffer
+ // needs to be.
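+  //
+  // Example (illustrative sketch; ShapeUtil::ByteSizeOf from shape_util.h is
+  // assumed here as a way to size the buffer):
+  //
+  //   TF_ASSIGN_OR_RETURN(Shape shape, client->GetShape(data));
+  //   std::vector<char> buffer(ShapeUtil::ByteSizeOf(shape));
+  //   TF_RETURN_IF_ERROR(client->TransferInProcess(data, buffer.data()));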
+ Status TransferInProcess(const GlobalData& data, void* destination);
+
+ // Transfer array to the service from the given buffer with the given shape
+ // and layout. The service creates an internal copy of the data so the client
+ // can free the buffer when this method returns.
+ StatusOr<std::unique_ptr<GlobalData>> TransferToServerInProcess(
+ const Shape& shape, const void* buffer);
+
+ StatusOr<Computation> LoadSnapshot(const SessionModule& module);
+
+ ServiceInterface* stub() { return stub_; }
+
+ private:
+ // Returns the execution statistics (e.g., gflop/s) as a string from the
+ // ExecutionProfile returned from an execution of the computation.
+ StatusOr<string> ExecutionStatsAsString(const Computation& computation,
+ const ExecutionProfile& profile);
+
+  ServiceInterface* stub_;  // Stub that this client is connected to.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Client);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_H_
diff --git a/tensorflow/compiler/xla/client/client_library.cc b/tensorflow/compiler/xla/client/client_library.cc
new file mode 100644
index 0000000000..93437023bc
--- /dev/null
+++ b/tensorflow/compiler/xla/client/client_library.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+LocalClientOptions& LocalClientOptions::set_platform(
+ perftools::gputools::Platform* platform) {
+ platform_ = platform;
+ return *this;
+}
+
+perftools::gputools::Platform* LocalClientOptions::platform() const {
+ return platform_;
+}
+
+LocalClientOptions& LocalClientOptions::set_number_of_replicas(
+ int number_of_replicas) {
+ number_of_replicas_ = number_of_replicas;
+ return *this;
+}
+
+int LocalClientOptions::number_of_replicas() const {
+ return number_of_replicas_;
+}
+
+/* static */ ClientLibrary& ClientLibrary::Singleton() {
+ static ClientLibrary* c = new ClientLibrary;
+ return *c;
+}
+
+ClientLibrary::ClientLibrary() = default;
+ClientLibrary::~ClientLibrary() = default;
+
+/* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
+ perftools::gputools::Platform* platform) {
+ LocalClientOptions default_options;
+ default_options.set_platform(platform);
+ return GetOrCreateLocalClient(default_options);
+}
+
+/* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
+ const LocalClientOptions& options) {
+ perftools::gputools::Platform* platform = options.platform();
+ int replica_count = options.number_of_replicas();
+ ClientLibrary& client_library = Singleton();
+ tensorflow::mutex_lock lock(client_library.service_mutex_);
+
+ if (platform == nullptr) {
+ TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+ }
+
+ auto it = client_library.instances_.find(platform->id());
+ if (it != client_library.instances_.end()) {
+ return it->second->client.get();
+ }
+
+ ServiceOptions service_options;
+ service_options.set_platform(platform);
+ service_options.set_number_of_replicas(replica_count);
+
+ std::unique_ptr<LocalInstance> instance = MakeUnique<LocalInstance>();
+ TF_ASSIGN_OR_RETURN(instance->service,
+ LocalService::NewService(service_options));
+ instance->client = MakeUnique<LocalClient>(instance->service.get());
+ LocalClient* cl = instance->client.get();
+
+ client_library.instances_.insert(
+ std::make_pair(platform->id(), std::move(instance)));
+ return cl;
+}
+
+/* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
+ auto client_status = GetOrCreateLocalClient();
+ TF_CHECK_OK(client_status.status());
+ return client_status.ValueOrDie();
+}
+
+/* static */ LocalService* ClientLibrary::GetXlaService(
+ perftools::gputools::Platform* platform) {
+ ClientLibrary& client_library = Singleton();
+ tensorflow::mutex_lock lock(client_library.service_mutex_);
+ auto it = client_library.instances_.find(platform->id());
+ CHECK(it != client_library.instances_.end());
+ return it->second->service.get();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/client_library.h b/tensorflow/compiler/xla/client/client_library.h
new file mode 100644
index 0000000000..2bc319f933
--- /dev/null
+++ b/tensorflow/compiler/xla/client/client_library.h
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The "client library" instantiates a local (in-process) XLA service for
+// use by this process, and connects to it with a singleton XLA local
+// client. ClientLibrary::GetOrCreateLocalClient will spawn a local service,
+// and return a client that's connected to it and ready to run XLA
+// computations.
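+//
+// Example (illustrative sketch; assumes LocalClient exposes the Client
+// interface declared in client.h):
+//
+//   xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
+//   auto result_or = client->ExecuteAndTransfer(computation, /*arguments=*/{});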
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+
+// Options to configure the local client when it is created.
+class LocalClientOptions {
+ public:
+ // Set the platform backing the service, or nullptr for the default platform.
+ LocalClientOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // Set the number of replicas to use when compiling replicated
+ // programs. The default is -1 meaning that the value is read from
+ // the xla_replicas flag.
+ LocalClientOptions& set_number_of_replicas(int number_of_replicas);
+ int number_of_replicas() const;
+
+ private:
+ perftools::gputools::Platform* platform_ = nullptr;
+ int number_of_replicas_ = -1;
+};
+
+class ClientLibrary {
+ public:
+ // Singleton constructor-or-accessor -- returns a client for the application
+ // to issue XLA commands on. Arguments:
+ //
+ // platform : The platform the underlying XLA service should target. If
+  //    null then the default platform is used.
+ static StatusOr<LocalClient*> GetOrCreateLocalClient(
+ perftools::gputools::Platform* platform = nullptr);
+ static StatusOr<LocalClient*> GetOrCreateLocalClient(
+ const LocalClientOptions& options);
+
+ // Convenience "or-die" wrapper around the above which returns the existing
+ // client library or creates one with default platform and allocator.
+ static LocalClient* LocalClientOrDie();
+
+  // Returns the local service backing the client for the given platform. Only
+  // used in unit tests to access user computations from the client.
+ static LocalService* GetXlaService(perftools::gputools::Platform* platform);
+
+ private:
+ // Returns the singleton instance of ClientLibrary.
+ static ClientLibrary& Singleton();
+
+ ClientLibrary();
+ ~ClientLibrary();
+
+ struct LocalInstance {
+ // Service that is wrapped by the singleton client object.
+ std::unique_ptr<LocalService> service;
+ // Singleton client object.
+ std::unique_ptr<LocalClient> client;
+ };
+
+ tensorflow::mutex service_mutex_; // Guards the singleton creation state.
+ std::unordered_map<perftools::gputools::Platform::Id,
+ std::unique_ptr<LocalInstance>>
+ instances_ GUARDED_BY(service_mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ClientLibrary);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_CLIENT_LIBRARY_H_
diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc
new file mode 100644
index 0000000000..cd7d8df58b
--- /dev/null
+++ b/tensorflow/compiler/xla/client/computation.cc
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/computation.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+Computation::Computation() : parent_(nullptr) {}
+
+Computation::Computation(ServiceInterface* parent,
+ const ComputationHandle& handle)
+ : handle_(handle), parent_(parent) {}
+
+Computation::Computation(Computation&& computation)
+ : handle_(computation.handle_), parent_(computation.parent_) {
+ computation.ResetWithoutFreeing();
+}
+
+void Computation::Reset() {
+ // TODO(leary) deallocate any owned computation.
+ ResetWithoutFreeing();
+}
+
+StatusOr<std::unique_ptr<SessionModule>> Computation::Snapshot() const {
+ SnapshotComputationRequest request;
+ *request.mutable_computation() = handle_;
+ SnapshotComputationResponse response;
+
+ TF_RETURN_IF_ERROR(parent_->SnapshotComputation(&request, &response));
+
+ return WrapUnique(response.release_module());
+}
+
+Computation::~Computation() { Reset(); }
+
+Computation& Computation::operator=(Computation&& computation) {
+ if (&computation != this) {
+ Reset();
+ handle_ = computation.handle_;
+ parent_ = computation.parent_;
+ computation.ResetWithoutFreeing();
+ }
+ return *this;
+}
+
+void Computation::ResetWithoutFreeing() {
+ handle_.Clear();
+ parent_ = nullptr;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h
new file mode 100644
index 0000000000..b595172486
--- /dev/null
+++ b/tensorflow/compiler/xla/client/computation.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Wraps a ComputationHandle protobuf with a lifetime. Computation is
+// movable and not copyable to capture the same kind of unique
+// ownership that std::unique_ptr represents.
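+//
+// Example (illustrative; assumes `a` holds a valid handle, e.g. one returned
+// by Client::LoadSnapshot):
+//
+//   Computation b = std::move(a);  // `b` now owns the handle.
+//   CHECK(a.IsNull());             // `a` was reset by the move.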
+class Computation {
+ public:
+ // Creates a null Computation.
+ Computation();
+
+ // parent: stub for the service on which we will deallocate the computation
+ // when it is no longer needed.
+ // handle: the computation handle protobuf from the service.
+ Computation(ServiceInterface* parent, const ComputationHandle& handle);
+
+ Computation(Computation&& computation);
+
+ // Deallocates the computation.
+ ~Computation();
+
+ Computation& operator=(Computation&& computation);
+
+ // Returns the underlying handle.
+ const ComputationHandle& handle() const { return handle_; }
+
+ // Sets handle to a null state and clears any owned computation.
+ void Reset();
+
+ // Requests that we snapshot the computation into a serializable protocol
+ // buffer form.
+ StatusOr<std::unique_ptr<SessionModule>> Snapshot() const;
+
+ // Returns true if this object is a null Computation.
+ bool IsNull() const { return parent_ == nullptr; }
+
+ private:
+ void ResetWithoutFreeing();
+
+ ComputationHandle handle_; // Handle that is wrapped by this class.
+
+ // Stub that the handle is deallocated on when this object's lifetime ends.
+ ServiceInterface* parent_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Computation);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_H_
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
new file mode 100644
index 0000000000..2b8b0b6ae5
--- /dev/null
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -0,0 +1,1539 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+
+#include <stddef.h>
+#include <array>
+#include <numeric>
+#include <set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+
+ComputationDataHandle ComputationBuilder::ParseOpResponse(
+ const Status& status, OpResponse* response) {
+ VLOG(2) << "done with op request";
+
+ if (!status.ok()) {
+ NoteError(status);
+ return ComputationDataHandle();
+ }
+
+ if (response->output().handle() == 0) {
+ NoteError(InternalError("No output handle"));
+ return ComputationDataHandle();
+ }
+ return response->output();
+}
+
+ComputationBuilder::ComputationBuilder(Client* client,
+ const string& computation_name)
+ : name_(computation_name), first_error_(Status::OK()), client_(client) {}
+
+ComputationBuilder::~ComputationBuilder() {}
+
+void ComputationBuilder::NoteError(const Status& error) {
+ if (die_immediately_on_error_) {
+ LOG(FATAL) << "error building computation: " << error;
+ }
+
+ if (first_error_.ok()) {
+ first_error_ = error;
+ first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
+ }
+}
+
+std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
+ const string& computation_name) {
+ auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
+ sub_builder->parent_builder_ = this;
+ sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
+ return sub_builder;
+}
+
+Status ComputationBuilder::PrepareComputation() {
+ if (!first_error_.ok()) {
+ return first_error_;
+ }
+ if (!computation_.IsNull()) {
+ return Status::OK();
+ }
+
+ ComputationRequest request;
+ request.set_name(name_);
+ ComputationResponse response;
+
+ VLOG(2) << "making computation request";
+ Status s = client_->stub()->Computation(&request, &response);
+ VLOG(2) << "done with computation request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return first_error_;
+ }
+
+ computation_ = Computation(client_->stub(), response.computation());
+ return Status::OK();
+}
+
+bool ComputationBuilder::MakeWindow(
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
+ const auto verify_size = [&](const int64 x, const char* x_name) {
+ if (x == 0 || x == window_dimensions.size()) {
+ return true;
+ } else {
+ NoteError(InvalidArgument(
+ "%s",
+ tensorflow::strings::StrCat(
+ "Window has different number of window dimensions than of ",
+ x_name, "\nNumber of window dimensions: ",
+ window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
+ "\n")
+ .c_str())); //
+ return false;
+ }
+ };
+ if (!verify_size(window_strides.size(), "window strides") ||
+ !verify_size(padding.size(), "padding entries") ||
+ !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
+ !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
+ return false;
+ }
+
+ window->Clear();
+ for (size_t i = 0; i < window_dimensions.size(); i++) {
+ auto dim = window->add_dimensions();
+ dim->set_size(window_dimensions[i]);
+ if (!window_strides.empty()) {
+ dim->set_stride(window_strides[i]);
+ } else {
+ dim->set_stride(1);
+ }
+ if (!padding.empty()) {
+ dim->set_padding_low(padding[i].first);
+ dim->set_padding_high(padding[i].second);
+ } else {
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ }
+ if (!lhs_dilation.empty()) {
+ dim->set_base_dilation(lhs_dilation[i]);
+ } else {
+ dim->set_base_dilation(1);
+ }
+ if (!rhs_dilation.empty()) {
+ dim->set_window_dilation(rhs_dilation[i]);
+ } else {
+ dim->set_window_dilation(1);
+ }
+ }
+ return true;
+}
+
+ComputationDataHandle ComputationBuilder::ConstantOp(
+ const PopulateLiteral& populate) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ConstantRequest request;
+ Literal* literal = request.mutable_literal();
+ populate(literal);
+ VLOG(3) << "created constant: " << literal->ShortDebugString();
+ OpRequest op_request;
+ *op_request.mutable_constant_request() = request;
+ *op_request.mutable_computation() = computation_.handle();
+ OpResponse response;
+
+ VLOG(2) << "making constant request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::ConstantLiteral(
+ const Literal& literal) {
+ return ConstantOp(
+ [literal](Literal* mutable_literal) { *mutable_literal = literal; });
+}
+
+ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
+ const Shape& shape,
+ const string& name) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ParameterRequest request;
+ *request.mutable_shape() = shape;
+ request.set_parameter(parameter_number);
+ request.set_name(name);
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_parameter_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making parameter request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
+ const ComputationDataHandle& operand) {
+ if (!first_error_.ok()) {
+ return first_error_;
+ }
+
+ GetLocalShapeRequest request;
+ *request.mutable_computation() = computation_.handle();
+ *request.mutable_operand() = operand;
+ GetLocalShapeResponse response;
+
+ VLOG(2) << "making get-shape request";
+ Status s = client_->stub()->GetLocalShape(&request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return first_error_;
+ }
+ TF_RET_CHECK(response.has_shape());
+ std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
+ TF_RET_CHECK(shape != nullptr);
+ return std::move(shape);
+}
+
+ComputationDataHandle ComputationBuilder::CheckShape(
+ const ComputationDataHandle& operand, const Shape& expected_shape) {
+ std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
+ CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
+ << "want " << ShapeUtil::HumanString(expected_shape) << " got "
+ << ShapeUtil::HumanString(*actual_shape);
+ return operand;
+}
+
+void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs) {
+ std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
+ std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
+ VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
+ << ShapeUtil::HumanString(*rhs_shape);
+ CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
+ << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
+ << ShapeUtil::HumanString(*rhs_shape);
+}
+
+ComputationDataHandle ComputationBuilder::Slice(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ SliceRequest request;
+ *request.mutable_operand() = operand;
+ for (int64 index : start_indices) {
+ request.add_start_indices(index);
+ }
+ for (int64 index : limit_indices) {
+ request.add_limit_indices(index);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_slice_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making slice request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::DynamicSlice(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ DynamicSliceRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_start_indices() = start_indices;
+ for (int64 index : slice_sizes) {
+ request.add_slice_sizes(index);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_dynamic_slice_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making dynamic slice request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
+ const ComputationDataHandle& operand, const ComputationDataHandle& update,
+ const ComputationDataHandle& start_indices) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ DynamicUpdateSliceRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_update() = update;
+ *request.mutable_start_indices() = start_indices;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_dynamic_update_slice_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making dynamic update slice request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::ConcatInDim(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ int64 dimension) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ConcatenateRequest request;
+ for (const ComputationDataHandle& operand : operands) {
+ *request.add_operands() = operand;
+ }
+ request.set_dimension(dimension);
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_concatenate_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making concatenate request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Broadcast(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ BroadcastRequest request;
+ *request.mutable_operand() = operand;
+ for (int64 size : broadcast_sizes) {
+ request.add_broadcast_sizes(size);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_broadcast_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making broadcast request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Pad(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& padding_value,
+ const PaddingConfig& padding_config) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ PadRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_padding_value() = padding_value;
+ *request.mutable_padding_config() = padding_config;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_pad_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making pad request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Reshape(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ReshapeRequest request;
+ *request.mutable_operand() = operand;
+ for (int64 dimension : dimensions) {
+ request.add_dimensions(dimension);
+ }
+ for (int64 new_size : new_sizes) {
+ request.add_new_sizes(new_size);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_reshape_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making reshape request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Reshape(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ if (!first_error_.ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
+ if (!shape.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape.status();
+ return ComputationDataHandle();
+ }
+ std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size());
+ std::iota(dimensions.begin(), dimensions.end(), 0);
+ return Reshape(operand, dimensions, new_sizes);
+}
+
+ComputationDataHandle ComputationBuilder::Collapse(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dims_to_collapse) {
+ if (!first_error_.ok()) {
+ return ComputationDataHandle();
+ }
+
+  // Out-of-order collapse is not supported here: check that the collapsed
+  // dimensions are in order and consecutive.
+  for (int i = 1; i < dims_to_collapse.size(); ++i) {
+    if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) {
+      NoteError(InvalidArgument(
+          "Collapsed dimensions must be in order and consecutive."));
+ return ComputationDataHandle();
+ }
+ }
+
+ // Create a new sizes vector from the old shape, replacing the collapsed
+ // dimensions by the product of their sizes.
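+  // For example, collapsing dims {1, 2} of an f32[2, 3, 4, 5] operand yields
+  // new_sizes of {2, 12, 5}.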
+ StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
+ if (!shape_or_status.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape_or_status.status();
+ return ComputationDataHandle();
+ }
+ std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();
+
+ std::vector<int64> new_sizes;
+ for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
+ if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
+ new_sizes.push_back(original_shape->dimensions(i));
+ } else {
+ new_sizes.back() *= original_shape->dimensions(i);
+ }
+ }
+
+ return Reshape(operand, new_sizes);
+}
+
+void ComputationBuilder::Trace(const string& tag,
+ const ComputationDataHandle& operand) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return;
+ }
+
+ TraceRequest request;
+ request.set_tag(tag);
+ *request.mutable_operand() = operand;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_trace_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making trace request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ }
+}
+
+ComputationDataHandle ComputationBuilder::Select(
+ const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
+ const ComputationDataHandle& on_false) {
+ return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
+}
+
+ComputationDataHandle ComputationBuilder::Tuple(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ VariadicOpRequest request;
+ request.set_varop(VAROP_TUPLE);
+ for (const ComputationDataHandle& operand : elements) {
+ *request.add_operands() = operand;
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_variadic_op_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making variadic op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::GetTupleElement(
+ const ComputationDataHandle& tuple_data, int64 index) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ GetTupleElementRequest request;
+ *request.mutable_operand() = tuple_data;
+ request.set_index(index);
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_get_tuple_element_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making get tuple element op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Eq(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Ne(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Ge(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Gt(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Le(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Lt(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Dot(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
+ return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{});
+}
+
+ComputationDataHandle ComputationBuilder::Conv(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ return ConvWithGeneralDimensions(
+ lhs, rhs, window_strides, padding,
+ CreateDefaultConvDimensionNumbers(window_strides.size()));
+}
+
+ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ return ConvGeneral(lhs, rhs, window_strides, padding,
+ CreateDefaultConvDimensionNumbers(window_strides.size()));
+}
+
+bool ComputationBuilder::VerifyConvolution(
+ const Shape& lhs_shape, const Shape& rhs_shape,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
+ NoteError(
+ InvalidArgument("Convolution arguments must have same number of "
+ "dimensions. Got: %s and %s",
+ ShapeUtil::HumanString(lhs_shape).c_str(),
+ ShapeUtil::HumanString(rhs_shape).c_str()));
+ return false;
+ }
+ int num_dims = ShapeUtil::Rank(lhs_shape);
+ if (num_dims < 3) {
+ NoteError(InvalidArgument(
+ "Convolution expects argument arrays with >= 3 dimensions. "
+ "Got: %s and %s",
+ ShapeUtil::HumanString(lhs_shape).c_str(),
+ ShapeUtil::HumanString(rhs_shape).c_str()));
+ return false;
+ }
+ int num_spatial_dims = num_dims - 2;
+
+ const auto check_spatial_dimensions = [&](
+ const char* const field_name,
+ const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
+ numbers) {
+ if (numbers.size() != num_spatial_dims) {
+ NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
+ num_spatial_dims, field_name, numbers.size()));
+ return false;
+ }
+ for (int i = 0; i < numbers.size(); ++i) {
+ if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
+ NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
+ field_name, i, numbers.Get(i)));
+ return false;
+ }
+ }
+ return true;
+ };
+ return check_spatial_dimensions("spatial_dimensions",
+ dimension_numbers.spatial_dimensions()) &&
+ check_spatial_dimensions(
+ "kernel_spatial_dimensions",
+ dimension_numbers.kernel_spatial_dimensions());
+}
+
+ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
+ if (!lhs_shape_or_status.ok()) {
+ first_error_ = lhs_shape_or_status.status();
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
+ if (!rhs_shape_or_status.ok()) {
+ first_error_ = rhs_shape_or_status.status();
+ return ComputationDataHandle();
+ }
+
+ std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
+ std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
+
+ if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
+    // Error is recorded in VerifyConvolution.
+ return ComputationDataHandle();
+ }
+
+ std::vector<int64> base_area_dimensions(
+ dimension_numbers.spatial_dimensions_size());
+ for (int i = 0; i < base_area_dimensions.size(); ++i) {
+ base_area_dimensions[i] =
+ lhs_shape->dimensions(dimension_numbers.spatial_dimensions(i));
+ }
+
+ std::vector<int64> window_dimensions(
+ dimension_numbers.kernel_spatial_dimensions_size());
+ for (int i = 0; i < window_dimensions.size(); ++i) {
+ window_dimensions[i] =
+ rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
+ }
+
+ return ConvGeneral(lhs, rhs, window_strides,
+ MakePadding(base_area_dimensions, window_dimensions,
+ window_strides, padding),
+ dimension_numbers);
+}
+
+ComputationDataHandle ComputationBuilder::ConvGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
+ dimension_numbers);
+}
+
+ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
+ if (!lhs_shape_or_status.ok()) {
+ first_error_ = lhs_shape_or_status.status();
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
+ if (!rhs_shape_or_status.ok()) {
+ first_error_ = rhs_shape_or_status.status();
+ return ComputationDataHandle();
+ }
+
+ std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
+ std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
+ if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
+ // Error is recorded in VerifyConvolution.
+ return ComputationDataHandle();
+ }
+
+ std::vector<int64> window_dimensions(
+ dimension_numbers.kernel_spatial_dimensions_size());
+ for (int i = 0; i < window_dimensions.size(); ++i) {
+ window_dimensions[i] =
+ rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
+ }
+
+ ConvolveRequest request;
+ *request.mutable_lhs() = lhs;
+ *request.mutable_rhs() = rhs;
+ *request.mutable_dimension_numbers() = dimension_numbers;
+
+ if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
+ rhs_dilation, request.mutable_window())) {
+ // Error is recorded in MakeWindow.
+ return ComputationDataHandle();
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_convolve_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making convolve request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ InfeedRequest request;
+ *request.mutable_shape() = shape;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_infeed_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making infeed op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Call(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ CallRequest request;
+ *request.mutable_to_apply() = computation.handle();
+ for (const ComputationDataHandle& operand : operands) {
+ *request.add_operands() = operand;
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_call_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making call op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::CustomCall(
+ const string& call_target_name,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const Shape& shape) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ CustomCallRequest request;
+ request.set_call_target_name(call_target_name);
+ for (const ComputationDataHandle& operand : operands) {
+ *request.add_operands() = operand;
+ }
+ *request.mutable_shape() = shape;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_custom_call_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making custom call op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Add(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Sub(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Mul(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Div(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Rem(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Max(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::Min(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::LogicalAnd(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_LOGICAL_AND, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::LogicalOr(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ return BinaryOp(BINOP_LOGICAL_OR, lhs, rhs, broadcast_dimensions);
+}
+
+ComputationDataHandle ComputationBuilder::LogicalNot(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_LOGICAL_NOT, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Abs(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_ABS, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Exp(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_EXP, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Floor(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_FLOOR, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Ceil(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_CEIL, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Log(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_LOG, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Sign(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_SIGN, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Tanh(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_TANH, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Transpose(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ if (!first_error_.ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
+ if (!shape.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape.status();
+ return ComputationDataHandle();
+ }
+ return Reshape(operand, permutation,
+ Permute(InversePermutation(permutation),
+ AsInt64Slice(shape.ValueOrDie()->dimensions())));
+}
+
+ComputationDataHandle ComputationBuilder::Rev(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ReverseRequest request;
+ *request.mutable_operand() = operand;
+ for (int64 dimension : dimensions) {
+ request.add_dimensions(dimension);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_reverse_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making reverse op request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Sort(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_SORT, operand);
+}
+
+ComputationDataHandle ComputationBuilder::SqrtF32(
+ const ComputationDataHandle& operand) {
+ return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
+ /*broadcast_dimensions=*/{});
+}
+
+ComputationDataHandle ComputationBuilder::Pow(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
+ return BinaryOp(BINOP_POW, lhs, rhs, /*broadcast_dimensions=*/{});
+}
+
+ComputationDataHandle ComputationBuilder::ConvertElementType(
+ const ComputationDataHandle& operand, PrimitiveType new_element_type) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
+ if (!shape_status.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape_status.status();
+ return ComputationDataHandle();
+ }
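+  // Note: the fetched shape itself is not used below; only the error check
+  // above has an effect here.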
+ std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();
+
+ ConvertRequest request;
+ *request.mutable_operand() = operand;
+ request.set_new_element_type(new_element_type);
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_convert_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making convert request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::SquareF32(
+ const ComputationDataHandle& operand) {
+ return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
+ /*broadcast_dimensions=*/{});
+}
+
+ComputationDataHandle ComputationBuilder::ReciprocalF32(
+ const ComputationDataHandle& operand) {
+ return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
+ /*broadcast_dimensions=*/{});
+}
+
+ComputationDataHandle ComputationBuilder::Neg(
+ const ComputationDataHandle& operand) {
+ return UnaryOp(UNOP_NEGATE, operand);
+}
+
+ComputationDataHandle ComputationBuilder::Clamp(
+ const ComputationDataHandle& min, const ComputationDataHandle& operand,
+ const ComputationDataHandle& max) {
+ return TernaryOp(TRIOP_CLAMP, min, operand, max);
+}
+
+ComputationDataHandle ComputationBuilder::UnaryOp(
+ UnaryOperation unop, const ComputationDataHandle& operand) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ UnaryOpRequest request;
+ request.set_unop(unop);
+ *request.mutable_operand() = operand;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_unary_op_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making unop request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::BinaryOp(
+ BinaryOperation binop, const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ BinaryOpRequest request;
+ request.set_binop(binop);
+ *request.mutable_lhs() = lhs;
+ *request.mutable_rhs() = rhs;
+ for (int64 dimension : broadcast_dimensions) {
+ request.add_broadcast_dimensions(dimension);
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making binop request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::RngOp(
+ RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
+ const Shape& shape) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ RngRequest request;
+ request.set_distribution(distribution);
+ for (const ComputationDataHandle& param : parameters) {
+ *request.add_parameter() = param;
+ }
+ *request.mutable_shape() = shape;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_rng_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making rngop request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::TernaryOp(
+ TernaryOperation triop, const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ TernaryOpRequest request;
+ request.set_triop(triop);
+ *request.mutable_lhs() = lhs;
+ *request.mutable_rhs() = rhs;
+ *request.mutable_ehs() = ehs;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_ternary_op_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making triop request";
+ Status s = client_->stub()->Op(&op_request, &response);
+
+ return ParseOpResponse(s, &response);
+}
+
+Status ComputationBuilder::SetReturnValue(
+ const ComputationDataHandle& operand) {
+ if (!first_error_.ok()) {
+ return first_error_;
+ }
+
+ SetReturnValueRequest request;
+ *request.mutable_computation() = computation_.handle();
+ *request.mutable_operand() = operand;
+
+ SetReturnValueResponse response;
+
+ VLOG(2) << "making set-handle-to-execute request";
+ Status s = client_->stub()->SetReturnValue(&request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return first_error_;
+ }
+
+ return Status::OK();
+}
+
+StatusOr<bool> ComputationBuilder::IsConstant(
+ const ComputationDataHandle& operand) {
+ if (!first_error_.ok()) {
+ return first_error_;
+ }
+
+ IsConstantRequest request;
+ *request.mutable_computation() = computation_.handle();
+ *request.mutable_operand() = operand;
+ IsConstantResponse response;
+
+ VLOG(2) << "making IsConstant request";
+ Status s = client_->stub()->IsConstant(&request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return first_error_;
+ }
+ return response.is_constant();
+}
+
+StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
+ const ComputationDataHandle& operand, const Layout* output_layout) {
+ if (!first_error_.ok()) {
+ return first_error_;
+ }
+
+ ComputeConstantRequest request;
+ *request.mutable_computation() = computation_.handle();
+ *request.mutable_operand() = operand;
+ if (output_layout != nullptr) {
+ *request.mutable_output_layout() = *output_layout;
+ }
+
+ ComputeConstantResponse response;
+
+ VLOG(2) << "making compute-constant request";
+ Status s = client_->stub()->ComputeConstant(&request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return first_error_;
+ }
+
+ TF_RET_CHECK(response.output().handle() != 0);
+ return MakeUnique<GlobalData>(client_->stub(), response.output());
+}
+
+ComputationDataHandle ComputationBuilder::Map(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ MapRequest request;
+ for (const ComputationDataHandle& operand : operands) {
+ *request.add_operands() = operand;
+ }
+ *request.mutable_to_apply() = computation.handle();
+ for (const ComputationDataHandle& sop : static_operands) {
+ *request.add_static_operands() = sop;
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_map_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making Map request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::RngNormal(
+ const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
+ const Shape& shape) {
+ return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
+}
+
+ComputationDataHandle ComputationBuilder::RngUniform(
+ const ComputationDataHandle& a, const ComputationDataHandle& b,
+ const Shape& shape) {
+ return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
+}
+
+ComputationDataHandle ComputationBuilder::RngBernoulli(
+ const ComputationDataHandle& mean, const Shape& shape) {
+ return RngOp(RandomDistribution::RNG_BERNOULLI, {mean}, shape);
+}
+
+ComputationDataHandle ComputationBuilder::While(
+ const Computation& condition, const Computation& body,
+ const ComputationDataHandle& init) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ WhileRequest request;
+ *request.mutable_condition() = condition.handle();
+ *request.mutable_body() = body.handle();
+ *request.mutable_init() = init;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_while_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making while request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::Reduce(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ReduceRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_init_value() = init_value;
+ for (int64 dimension : dimensions_to_reduce) {
+ request.add_dimensions(dimension);
+ }
+ *request.mutable_to_apply() = computation.handle();
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_reduce_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making reduce request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::ReduceWindow(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ if (!first_error_.ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
+ if (!shape.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape.status();
+ return ComputationDataHandle();
+ }
+
+ return ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
+ window_dimensions, window_strides, padding));
+}
+
+ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ ReduceWindowRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_to_apply() = computation.handle();
+ *request.mutable_init_value() = init_value;
+
+ if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
+ request.mutable_window())) {
+    // Error is recorded in MakeWindow.
+ return ComputationDataHandle();
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_reduce_window_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making reduce-window request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::CrossReplicaSum(
+ const ComputationDataHandle& operand) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ CrossReplicaSumRequest request;
+ *request.mutable_operand() = operand;
+ OpRequest op_request;
+ *op_request.mutable_cross_replica_sum_request() = request;
+ *op_request.mutable_computation() = computation_.handle();
+ OpResponse response;
+
+ VLOG(2) << "making cross-replica-sum request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+ComputationDataHandle ComputationBuilder::SelectAndScatter(
+ const ComputationDataHandle& operand, const Computation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ComputationDataHandle& source,
+ const ComputationDataHandle& init_value, const Computation& scatter) {
+ if (!first_error_.ok()) {
+ return ComputationDataHandle();
+ }
+
+ StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
+ if (!shape.ok()) {
+ // Just early return with the existing error status.
+ first_error_ = shape.status();
+ return ComputationDataHandle();
+ }
+ return SelectAndScatterWithGeneralPadding(
+ operand, select, window_dimensions, window_strides,
+ MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
+ window_dimensions, window_strides, padding),
+ source, init_value, scatter);
+}
+
+ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
+ const ComputationDataHandle& operand, const Computation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ComputationDataHandle& source,
+ const ComputationDataHandle& init_value, const Computation& scatter) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ SelectAndScatterRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_select() = select.handle();
+ *request.mutable_source() = source;
+ *request.mutable_init_value() = init_value;
+ *request.mutable_scatter() = scatter.handle();
+
+ if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
+ request.mutable_window())) {
+    // Error is recorded in MakeWindow.
+ return ComputationDataHandle();
+ }
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_select_and_scatter_request() = request;
+ OpResponse response;
+
+ VLOG(2) << "making select-and-scatter request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
+}
+
+void ComputationBuilder::Send(const ComputationDataHandle& operand,
+ const ChannelHandle& handle) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return;
+ }
+
+ SendRequest request;
+ *request.mutable_operand() = operand;
+ *request.mutable_channel_handle() = handle;
+ OpRequest op_request;
+ *op_request.mutable_send_request() = request;
+ *op_request.mutable_computation() = computation_.handle();
+ OpResponse response;
+
+ VLOG(2) << "making send request";
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ VLOG(2) << "done with request";
+
+ if (!s.ok()) {
+ NoteError(s);
+ return;
+ }
+}
+
+ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
+ const ChannelHandle& handle) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ RecvRequest request;
+ *request.mutable_shape() = shape;
+ *request.mutable_channel_handle() = handle;
+ OpRequest op_request;
+ *op_request.mutable_recv_request() = request;
+ *op_request.mutable_computation() = computation_.handle();
+ OpResponse response;
+
+ VLOG(2) << "making recv request";
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ VLOG(2) << "done with request";
+
+ return ParseOpResponse(s, &response);
+}
+
+Computation ComputationBuilder::BuildAndNoteError() {
+ DCHECK(parent_builder_ != nullptr);
+ auto build_status = Build();
+ if (!build_status.ok()) {
+ parent_builder_->NoteError(
+ AddStatus(build_status.status(),
+ tensorflow::strings::StrCat("error from: ", name_)));
+ return Computation();
+ }
+ return build_status.ConsumeValueOrDie();
+}
+
+StatusOr<Computation> ComputationBuilder::Build() {
+ if (!first_error_.ok()) {
+ string backtrace;
+ first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
+ return AppendStatus(first_error_, backtrace);
+ }
+
+ if (computation_.IsNull()) {
+ return FailedPrecondition("no computation was built");
+ }
+
+ return {std::move(computation_)};
+}
+
+/* static */ ConvolutionDimensionNumbers
+ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_batch_dimension(kConvBatchDimension);
+ dimension_numbers.set_feature_dimension(kConvFeatureDimension);
+ dimension_numbers.set_kernel_output_feature_dimension(
+ kConvKernelOutputDimension);
+ dimension_numbers.set_kernel_input_feature_dimension(
+ kConvKernelInputDimension);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ dimension_numbers.add_spatial_dimensions(i + 2);
+ dimension_numbers.add_kernel_spatial_dimensions(i + 2);
+ }
+ return dimension_numbers;
+}
+
+/* static */ StatusOr<ConvolutionDimensionNumbers>
+ComputationBuilder::CreateConvDimensionNumbers(
+ int64 batch, int64 feature, int64 first_spatial, int64 second_spatial,
+ int64 kernel_output_feature, int64 kernel_input_feature,
+ int64 kernel_first_spatial, int64 kernel_second_spatial) {
+ if (std::set<int64>({batch, feature, first_spatial, second_spatial}).size() !=
+ 4) {
+ return FailedPrecondition(
+ "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
+ "%lld)",
+ batch, feature, first_spatial, second_spatial);
+ }
+ if (std::set<int64>({kernel_output_feature, kernel_input_feature,
+ kernel_first_spatial, kernel_second_spatial})
+ .size() != 4) {
+ return FailedPrecondition(
+ "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
+ "%lld)",
+ kernel_output_feature, kernel_input_feature, kernel_first_spatial,
+ kernel_second_spatial);
+ }
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_batch_dimension(batch);
+ dimension_numbers.set_feature_dimension(feature);
+ dimension_numbers.add_spatial_dimensions(first_spatial);
+ dimension_numbers.add_spatial_dimensions(second_spatial);
+ dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
+ dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
+ dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
+ dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
+ return dimension_numbers;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
new file mode 100644
index 0000000000..a74257eae3
--- /dev/null
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -0,0 +1,783 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
+
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stacktrace.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Wraps an XLA client with a convenient interface for building up
+// computations. Any errors encountered in building up the computation are
+// deferred from being handled until Build() is called.
+//
+// Thread-compatible.
+class ComputationBuilder {
+ public:
+ // client: client in which to build the computation.
+ // computation_name: name to use for the built computation.
+ ComputationBuilder(Client* client, const string& computation_name);
+
+ ~ComputationBuilder();
+
+ // Returns the client the builder was initialized with.
+ Client* client() { return client_; }
+
+ // Returns the computation name.
+ const string& name() { return name_; }
+
+ // Sets the builder to a mode where it will die immediately when an error is
+ // encountered, rather than producing it in a deferred fashion when Build() is
+ // called (which is the default).
+ void set_die_immediately_on_error(bool enabled) {
+ die_immediately_on_error_ = enabled;
+ }
+
+ // Enqueues a "retrieve parameter value" instruction for a parameter that was
+ // passed to the computation.
+ ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape,
+ const string& name);
+
+ // Retrieves the (inferred) shape of the operand in the computation.
+ StatusOr<std::unique_ptr<Shape>> GetShape(
+ const ComputationDataHandle& operand);
+
+ // Checks that the operand has the given expected shape. Returns the operand
+ // if yes, fails with a CHECK error if no.
+ ComputationDataHandle CheckShape(const ComputationDataHandle& operand,
+ const Shape& expected_shape);
+
+ // Checks that the lhs and rhs results have the same shape.
+ void CheckSameShape(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs);
+
+ // Enqueues a constant with the value of the given literal onto the
+ // computation.
+ ComputationDataHandle ConstantLiteral(const Literal& literal);
+
+ // Enqueues a constant onto the computation. Methods are templated on the
+ // native host type (NativeT) which corresponds to a specific XLA
+ // PrimitiveType as given in the following table:
+ //
+  //  Native Type   PrimitiveType
+ // -----------------------------
+  //   bool          PRED
+  //   int32         S32
+  //   int64         S64
+  //   uint32        U32
+  //   uint64        U64
+  //   float         F32
+  //   double        F64
+ //
+ // Note: not all primitive types defined in xla_data.proto have a
+ // corresponding native type yet.
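+  //
+  // A brief illustrative sketch (assuming a ComputationBuilder `builder`):
+  //
+  //   auto half = builder.ConstantR0<float>(0.5f);               // scalar F32
+  //   auto vec = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f});  // f32[3]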
+ template <typename NativeT>
+ ComputationDataHandle ConstantR0(NativeT value);
+ template <typename NativeT>
+ ComputationDataHandle ConstantR1(tensorflow::gtl::ArraySlice<NativeT> values);
+ ComputationDataHandle ConstantR1(const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ ComputationDataHandle ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ ComputationDataHandle ConstantR2FromArray2D(const Array2D<NativeT>& values);
+ template <typename NativeT>
+ ComputationDataHandle ConstantR3FromArray3D(const Array3D<NativeT>& values);
+ template <typename NativeT>
+ ComputationDataHandle ConstantR4FromArray4D(const Array4D<NativeT>& values);
+
+ // Enqueues a rank one constant (vector) onto the computation. The vector has
+ // size 'length' and every element has the value 'value'.
+ template <typename NativeT>
+ ComputationDataHandle ConstantR1(int64 length, NativeT value);
+
+ // Adds dimensions to an array by duplicating the data in the array.
+ //
+ // The new dimensions are inserted on the left, i.e. if
+ // broadcast_sizes has values {a0, ..., aN} and the operand shape
+ // has dimensions {b0, ..., bM} then the shape of the output has
+ // dimensions {a0, ..., aN, b0, ..., bM}.
+ //
+ // The new dimensions index into copies of the operand, i.e.
+ //
+ // output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
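+  //
+  // For example (illustrative), broadcasting an f32[2, 3] operand with
+  // broadcast_sizes {4} yields an f32[4, 2, 3] result in which
+  // output[i, j, k] == operand[j, k].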
+ ComputationDataHandle Broadcast(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ // Enqueues a pad operation onto the computation that pads the given value on
+ // the edges as well as between the elements of the input. padding_config
+ // specifies the padding amount for each dimension.
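+  //
+  // A minimal sketch (illustrative; assumes an f32[3] operand `x` and a
+  // matching scalar constant `zero` built on a ComputationBuilder `builder`):
+  //
+  //   PaddingConfig config;
+  //   auto* dim = config.add_dimensions();
+  //   dim->set_edge_padding_low(1);   // one padding element before the data
+  //   dim->set_edge_padding_high(2);  // two padding elements after the data
+  //   dim->set_interior_padding(0);   // no padding between elements
+  //   auto padded = builder.Pad(x, zero, config);  // result shape: f32[6]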
+ ComputationDataHandle Pad(const ComputationDataHandle& operand,
+ const ComputationDataHandle& padding_value,
+ const PaddingConfig& padding_config);
+
+ // Enqueues an operation onto the computation that flattens the operand based
+ // on the dimension order (major/slowest-varying to minor/fastest-varying)
+ // given, followed by reshaping it into the shape with the given dimension
+ // sizes (also major to minor). Conceptually, this is a limited form of
+ // "shape casting".
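+  //
+  // For example (illustrative), an f32[2, 3] operand reshaped with
+  // dimensions {0, 1} and new_sizes {6} is flattened in row-major order into
+  // an f32[6] result.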
+ ComputationDataHandle Reshape(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+  // Enqueues an operation onto the computation that collapses the operand in
+  // its natural dimension order (0 to rank-1, i.e. major/slowest-varying to
+  // minor/fastest-varying), then reshapes the result into the shape with the
+  // given dimension sizes (also major to minor). Conceptually, this is a
+  // limited form of "shape casting".
+ ComputationDataHandle Reshape(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Wrapper for Reshape.
+ // Enqueues an operation to collapse the provided dimensions; e.g. an
+ // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
+ // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
+ // be a consecutive, in-order subsequence of the operand dimensions.
+ //
+ // This could potentially cause data to be moved -- it provides a more
+ // structured form of reshaping than an arbitrary Reshape operation.
+ ComputationDataHandle Collapse(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a slice operation onto the computation that slices the operand
+ // from the start indices to the limit indices; e.g.
+ //
+  //         x
+  //   [ 0 1 2 3 ]
+  // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
+  //   [ 8 9 a b ]
+ //
+ // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
+ // range notation.
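+  //
+  // In code (illustrative), the slice above corresponds to
+  // builder.Slice(operand, /*start_indices=*/{1, 1}, /*limit_indices=*/{2, 3}),
+  // which produces a result of shape f32[1, 2].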
+ ComputationDataHandle Slice(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices);
+
+ // Enqueues a slice operation onto the computation that slices the 'operand'
+ // from dynamic start indices which are passed in 'start_indices'.
+  // The size of the slice in each dimension is passed in 'slice_sizes'; the
+  // slice covers the half-open interval [start, start + size) in each
+  // dimension.
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo input dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
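+  //
+  // For example (illustrative), with an f32[3, 4] operand, runtime start
+  // indices of {1, 1} and slice_sizes {1, 2}, the result has shape f32[1, 2].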
+ ComputationDataHandle DynamicSlice(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ // Enqueues a dynamic update slice operation onto the computation, which
+ // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
+ // The shape of 'update' determines the shape of the slice of 'operand'
+ // which is updated.
+ // The indices specified in 'start_indices' specify the offset of the slice
+ // of 'operand' which is updated.
+ //
+  //              update = {10, 11}  // calculated at runtime.
+  //   [1 2 3]    start  = {1, 1}    // calculated at runtime.    [1 2  3 ]
+  //   [4 5 6] => DynamicUpdateSlice(data, update, start)   =>    [4 10 11]
+  //   [7 8 9]                                                    [7 8  9 ]
+ //
+ // The shape of 'start_indices' must be rank == 1, with dimension size
+ // equal to the rank of the 'operand'.
+ // Slice index calculations are computed modulo update dimension sizes to
+ // prevent dynamic start indices from generating out-of-bound array accesses.
+ ComputationDataHandle DynamicUpdateSlice(
+ const ComputationDataHandle& operand, const ComputationDataHandle& update,
+ const ComputationDataHandle& start_indices);
+
+ // Enqueues a concatenate instruction onto the computation. 'operands' must
+ // have >= 1 entry.
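+  //
+  // For example (illustrative), concatenating f32[2, 3] and f32[5, 3]
+  // operands along dimension 0 yields an f32[7, 3] result.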
+ ComputationDataHandle ConcatInDim(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ int64 dimension);
+
+  // Enqueues a tracing operation onto the computation; the computation will emit
+ // a logging message with the operand.
+ void Trace(const string& tag, const ComputationDataHandle& operand);
+
+ // Enqueues a conditional-move-like select operation onto the computation;
+ // predicated on pred, selects between on_true and on_false.
+ ComputationDataHandle Select(const ComputationDataHandle& pred,
+ const ComputationDataHandle& on_true,
+ const ComputationDataHandle& on_false);
+
+ // Enqueues a tuple-creation instruction onto the computation.
+ ComputationDataHandle Tuple(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> elements);
+
+ // Enqueues a tuple-element-get instruction onto the computation.
+ ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data,
+ int64 index);
+
+ // Enqueues an equal-to comparison instruction onto the computation.
+ ComputationDataHandle Eq(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a not-equal comparison instruction onto the computation.
+ ComputationDataHandle Ne(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-or-equal comparison instruction onto the computation.
+ ComputationDataHandle Ge(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a greater-than comparison instruction onto the computation.
+ ComputationDataHandle Gt(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-than comparison instruction onto the computation.
+ ComputationDataHandle Lt(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a less-or-equal comparison instruction onto the computation.
+ ComputationDataHandle Le(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a dot instruction onto the computation.
+ ComputationDataHandle Dot(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs);
+
+ // Default dimension numbers used for a 2D convolution.
+ static constexpr int64 kConvBatchDimension = 0;
+ static constexpr int64 kConvFeatureDimension = 1;
+ static constexpr int64 kConvFirstSpatialDimension = 2;
+ static constexpr int64 kConvSecondSpatialDimension = 3;
+ static constexpr int64 kConvKernelOutputDimension = 0;
+ static constexpr int64 kConvKernelInputDimension = 1;
+ static constexpr int64 kConvKernelFirstSpatialDimension = 2;
+ static constexpr int64 kConvKernelSecondSpatialDimension = 3;
+
+ // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
+ // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
+ // the kernel operand
+ // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
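+  // (In common framework terms this is an NCHW input layout and an OIHW
+  // kernel layout.)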
+ static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
+ int num_spatial_dims = 2);
+
+ // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an
+ // error if either the input or the weight dimension numbers have conflicts.
+ static StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
+ int64 batch, int64 feature, int64 first_spatial, int64 second_spatial,
+ int64 kernel_output_feature, int64 kernel_input_feature,
+ int64 kernel_first_spatial, int64 kernel_second_spatial);
+
+ // Enqueues a convolution instruction onto the computation, which uses the
+ // default convolution dimension numbers.
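+  //
+  // Illustrative sketch (names and shapes are assumptions): with an
+  // f32[N, C, H, W] input `x` and an f32[O, C, KH, KW] kernel `w`,
+  //
+  //   auto y = builder.Conv(x, w, /*window_strides=*/{1, 1}, Padding::kSame);
+  //
+  // produces an f32[N, O, H, W] output under the default dimension numbers.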
+ ComputationDataHandle Conv(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration in the format returned by MakePadding().
+ ComputationDataHandle ConvWithGeneralPadding(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided dimension numbers configuration.
+ ComputationDataHandle ConvWithGeneralDimensions(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the caller
+ // provided padding configuration as well as the dimension numbers.
+ ComputationDataHandle ConvGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues a convolution instruction onto the computation, with the
+ // caller-provided padding configuration, dilation factors and dimension
+ // numbers.
+ ComputationDataHandle ConvGeneralDilated(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Enqueues an infeed instruction onto the computation, which reads data of
+ // the given shape from the infeed buffer of the device.
+ ComputationDataHandle Infeed(const Shape& shape);
+
+ // Enqueues a call instruction onto the computation.
+ ComputationDataHandle Call(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands);
+
+ // Enqueues a custom call instruction onto the computation.
+ // During code generation, a call instruction is emitted which targets a
+ // symbol with the name |call_target_name|. The |operands| are passed to the
+ // call instruction. |shape| is the resultant shape.
+ ComputationDataHandle CustomCall(
+ const string& call_target_name,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const Shape& shape);
+
+ // The following methods enqueue element-wise binary arithmetic operations
+ // onto the computation. The shapes of the operands have to match unless one
+ // of the operands is a scalar, or an explicit broadcast dimension is given
+ // (see g3doc for more details).
+
+ // Enqueues an add instruction onto the computation.
+ ComputationDataHandle Add(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a subtract instruction onto the computation.
+ ComputationDataHandle Sub(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a multiply instruction onto the computation.
+ ComputationDataHandle Mul(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a divide instruction onto the computation.
+ ComputationDataHandle Div(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a remainder instruction onto the computation.
+ ComputationDataHandle Rem(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a max instruction onto the computation.
+ ComputationDataHandle Max(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Enqueues a min instruction onto the computation.
+ ComputationDataHandle Min(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ // Element-wise logical operators
+ ComputationDataHandle LogicalAnd(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ ComputationDataHandle LogicalOr(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions = {});
+
+ ComputationDataHandle LogicalNot(const ComputationDataHandle& lhs);
+
+ // Reduces an array along the provided dimensions, using "computation" as the
+ // reduction operator.
+ ComputationDataHandle Reduce(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce);
+
+ // Enqueues a windowed reduce instruction onto the computation.
+ ComputationDataHandle ReduceWindow(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding);
+
+ // As ReduceWindow(), but the padding is given in the format
+ // returned by MakePadding().
+ ComputationDataHandle ReduceWindowWithGeneralPadding(
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& init_value, const Computation& computation,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
+
+ // Returns the sum of the operand value across all replicas. All replicas
+ // supply one input to the sum and all replicas receive the resulting sum.
+ ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand);
+
+ // Enqueues an operation that scatters the `source` array to the selected
+ // indices of each window.
+ ComputationDataHandle SelectAndScatter(
+ const ComputationDataHandle& operand, const Computation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
+ const ComputationDataHandle& source,
+ const ComputationDataHandle& init_value, const Computation& scatter);
+
+ // As SelectAndScatter(), but the padding is given in the format
+ // returned by MakePadding().
+ ComputationDataHandle SelectAndScatterWithGeneralPadding(
+ const ComputationDataHandle& operand, const Computation& select,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ const ComputationDataHandle& source,
+ const ComputationDataHandle& init_value, const Computation& scatter);
+
+ // Enqueues an abs instruction onto the computation.
+ ComputationDataHandle Abs(const ComputationDataHandle& operand);
+
+ // Enqueues an exp instruction onto the computation.
+ ComputationDataHandle Exp(const ComputationDataHandle& operand);
+
+ // Enqueues a floor instruction onto the computation.
+ ComputationDataHandle Floor(const ComputationDataHandle& operand);
+
+ // Enqueues a ceil instruction onto the computation.
+ ComputationDataHandle Ceil(const ComputationDataHandle& operand);
+
+ // Enqueues a log instruction (natural logarithm) onto the computation.
+ ComputationDataHandle Log(const ComputationDataHandle& operand);
+
+ // Enqueues a sign instruction onto the computation.
+ ComputationDataHandle Sign(const ComputationDataHandle& operand);
+
+ // Enqueues a tanh instruction onto the computation.
+ ComputationDataHandle Tanh(const ComputationDataHandle& operand);
+
+ // Enqueues a float32 sqrt instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 0.5f constant
+ // exponent).
+ ComputationDataHandle SqrtF32(const ComputationDataHandle& operand);
+
+ // Enqueues a float32 square instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 2.0f constant
+ // exponent).
+ ComputationDataHandle SquareF32(const ComputationDataHandle& operand);
+
+ // Enqueues a lhs^rhs computation onto the computation.
+ ComputationDataHandle Pow(const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs);
+
+ // Enqueues a convert instruction onto the computation that changes the
+ // element type of the operand array to primitive_type.
+ ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand,
+ PrimitiveType new_element_type);
+
+ // Enqueues a float32 reciprocal instruction onto the computation.
+ // (float32 is specified as there is an implicit float32 -1.0f constant
+ // exponent).
+ //
+ // TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of
+ // the operand.
+ ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
+
+ // Enqueues a negate instruction onto the computation.
+ ComputationDataHandle Neg(const ComputationDataHandle& operand);
+
+ // Enqueues a transpose instruction onto the computation.
+ ComputationDataHandle Transpose(
+ const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+
+ // Enqueues a reverse instruction onto the computation. The order of the
+ // elements in the given dimensions is reversed (i.e., the element at index i
+ // is moved to index dimension_size - 1 - i).
+ ComputationDataHandle Rev(const ComputationDataHandle& operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Enqueues a sort (in increasing order) instruction onto the computation.
+ ComputationDataHandle Sort(const ComputationDataHandle& operand);
+
+ // Enqueues a clamp instruction onto the computation.
+ ComputationDataHandle Clamp(const ComputationDataHandle& min,
+ const ComputationDataHandle& operand,
+ const ComputationDataHandle& max);
+
+ // Enqueues a map instruction onto the computation.
+ ComputationDataHandle Map(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands = {});
+
+ // Enqueues an N(mu, sigma) random number generation instruction onto the
+ // computation.
+ ComputationDataHandle RngNormal(const ComputationDataHandle& mu,
+ const ComputationDataHandle& sigma,
+ const Shape& shape);
+
+ // Enqueues a U(a, b) random number generation instruction onto the
+ // computation.
+ ComputationDataHandle RngUniform(const ComputationDataHandle& a,
+ const ComputationDataHandle& b,
+ const Shape& shape);
+
+ // Enqueues a B(1, p) random number generation instruction onto the
+ // computation.
+ ComputationDataHandle RngBernoulli(const ComputationDataHandle& mean,
+ const Shape& shape);
+
+ // Enqueues a while node onto the computation.
+ ComputationDataHandle While(const Computation& condition,
+ const Computation& body,
+ const ComputationDataHandle& init);
+
+ // Enqueues a Send node onto the computation, to send the given operand to
+ // a Recv instruction that shares the same channel handle.
+ void Send(const ComputationDataHandle& operand, const ChannelHandle& handle);
+
+ // Enqueues a Recv node onto the computation. The data comes from a Send
+ // instruction that shares the same channel handle and its shape must
+ // be the same as the given shape.
+ ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle);
+
+ // Returns true if 'operand' is a compile-time constant. A compile-time
+ // constant does not depend on parameters, or on stateful operators such
+ // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests
+ // whether a computation is a compile-time constant without evaluating the
+ // computation.
+ StatusOr<bool> IsConstant(const ComputationDataHandle& operand);
+
+ // Computes the value of a constant indicated by a
+ // ComputationDataHandle.
+ //
+ // The handle must be from the computation currently being built -
+ // i.e., returned from this builder with no intervening call to
+ // Build(). This happens to currently work regardless of that, but
+ // that may stop working at any time.
+ //
+ // The handle must represent a constant value, which in this case
+ // means that it must not statically depend on a parameter to the
+ // computation that is being built.
+ //
+ // `IsConstant` can be used to test whether a computation is a compile-time
+ // constant without evaluating it. `ComputeConstant` only succeeds for
+ // computations where `IsConstant` returns true.
+ //
+ // This functionality can be useful when translating a computation into XLA
+ // where something that looked dynamic is required by XLA to be specified as
+ // a constant. E.g., the source computation (outside of XLA) may include a
+ // dynamic computation of the shape of something, and ComputeConstant lets
+ // you determine the value of that computation when it can be determined at
+ // compile time.
+ //
+ // If output_layout is non-null, then the output of the computation
+ // will be stored using that layout.
+ StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
+ const ComputationDataHandle& handle,
+ const Layout* output_layout = nullptr);
+
+ // Returns a new ComputationBuilder whose resultant Computation is used only
+ // by this ComputationBuilder. The sub-ComputationBuilder has the same
+ // die_immediately_on_error behavior as the parent.
+ std::unique_ptr<ComputationBuilder> CreateSubBuilder(
+ const string& computation_name);
+
+ // Modifies the computation being built so that executions of it
+ // will return the value associated with operand, rather than the
+ // last expression enqueued on the ComputationBuilder. Any subsequent
+ // operations added to the ComputationBuilder will not have any effect unless
+ // SetReturnValue is called again.
+ Status SetReturnValue(const ComputationDataHandle& operand);
+
+ // Builds the computation with the requested operations, or returns a non-ok
+ // status.
+ StatusOr<Computation> Build();
+
+ // Builds the computation with the requested operations, or notes an error in
+ // the parent ComputationBuilder and returns an empty computation if building
+ // failed. This function is intended to be used where the returned Computation
+ // is only used by the parent ComputationBuilder, so any further operations on
+ // the returned Computation simply produce errors if an error occurred while
+ // building this computation. If the built computation is to be used by a
+ // ComputationBuilder other than the parent ComputationBuilder, then Build()
+ // should be used instead.
+ Computation BuildAndNoteError();
+
+ private:
+ using PopulateLiteral = std::function<void(Literal*)>;
+
+ // Limited checking of convolution parameters. Returns false on
+ // error.
+ bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // The parent ComputationBuilder of a sub-ComputationBuilder.
+ // parent_builder_ will be nullptr if this is not a sub-ComputationBuilder.
+ ComputationBuilder* parent_builder_{nullptr};
+
+ // Helper function for creating a Window proto from user-supplied
+ // data. Returns true if the user-supplied data was valid.
+ bool MakeWindow(tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
+ tensorflow::gtl::ArraySlice<int64> lhs_dilation,
+ tensorflow::gtl::ArraySlice<int64> rhs_dilation,
+ Window* window);
+
+ // Internal helper method that makes a request for a constant operation -- the
+ // provided function is used to populate the literal before sending the
+ // request.
+ ComputationDataHandle ConstantOp(const PopulateLiteral& populate);
+
+ // Internal helper method that does the building for an arbitrary unary op.
+ ComputationDataHandle UnaryOp(UnaryOperation unop,
+ const ComputationDataHandle& operand);
+
+ // Internal helper method that does the building for an arbitrary binary op.
+ // broadcast_dimensions specifies which dimensions to use for broadcasting
+ // when the operation is between tensors of different ranks.
+ ComputationDataHandle BinaryOp(
+ BinaryOperation binop, const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Internal helper method that does the building for an arbitrary ternary op.
+ ComputationDataHandle TernaryOp(TernaryOperation triop,
+ const ComputationDataHandle& lhs,
+ const ComputationDataHandle& rhs,
+ const ComputationDataHandle& ehs);
+
+ // Internal helper method that does the building for a random number generator
+ // of a given distribution with an explicitly specified shape.
+ ComputationDataHandle RngOp(
+ RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
+ const Shape& shape);
+
+ // Populates computation_ with a valid object or returns a failing status.
+ // This is used before any given operation is enqueued.
+ Status PrepareComputation();
+
+ // Helper function for parsing a method response and either returning the
+ // output computation data handle (on success) or a vacuous computation data
+ // handle (on failure).
+ ComputationDataHandle ParseOpResponse(const Status& status,
+ OpResponse* response);
+
+ // Notes that the error occurred by:
+ // * storing it internally and capturing a backtrace if it's the first error
+ // (this deferred value will be produced on the call to Build())
+ // * dying if die_immediately_on_error_ is true
+ void NoteError(const Status& error);
+
+ string name_; // Name to use for the built computation.
+
+ // The first error encountered while building the computation.
+ // This is OK until the first error is encountered.
+ Status first_error_;
+
+ // The saved stack trace from the point at which the first error occurred.
+ tensorflow::SavedStackTrace first_error_backtrace_;
+
+ // The computation that operations are enqueued onto.
+ Computation computation_;
+
+ // The client that the computation is created in. Not owned.
+ Client* client_;
+
+ // Mode bit that indicates whether to die when the first error is encountered.
+ bool die_immediately_on_error_{false};
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder);
+};
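A sketch of the IsConstant/ComputeConstant contract documented in the class above
(illustrative only: `builder` is assumed to wrap a live client, and the helper
function name is not part of this header):

    #include <memory>

    #include "tensorflow/compiler/xla/client/computation_builder.h"
    #include "tensorflow/compiler/xla/client/global_data.h"

    void ComputeConstantSketch(xla::ComputationBuilder& builder) {
      // The handle is built purely from constants (no parameters, no RNG, no
      // Infeed), so IsConstant is expected to report true.
      xla::ComputationDataHandle c = builder.Add(
          builder.ConstantR0<float>(1.0f), builder.ConstantR0<float>(41.0f));
      xla::StatusOr<bool> is_constant = builder.IsConstant(c);
      // ComputeConstant only succeeds for handles where IsConstant is true; the
      // constant is evaluated by the service and returned as GlobalData.
      xla::StatusOr<std::unique_ptr<xla::GlobalData>> value =
          builder.ComputeConstant(c);
      (void)is_constant;
      (void)value;
    }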
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) {
+ return ConstantOp(
+ [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR1(
+ tensorflow::gtl::ArraySlice<NativeT> values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR1(values, literal);
+ });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR1(int64 length,
+ NativeT value) {
+ return ConstantOp([length, value](Literal* literal) {
+ LiteralUtil::PopulateWithValue(value, {length}, literal);
+ });
+}
+
+inline ComputationDataHandle ComputationBuilder::ConstantR1(
+ const tensorflow::core::Bitmap& values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR1(values, literal);
+ });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR2(values, literal);
+ });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
+ const Array2D<NativeT>& values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR2FromArray2D(values, literal);
+ });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
+ const Array3D<NativeT>& values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR3FromArray3D(values, literal);
+ });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
+ const Array4D<NativeT>& values) {
+ return ConstantOp([&values](Literal* literal) {
+ LiteralUtil::PopulateR4FromArray4D(values, literal);
+ });
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_
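A usage sketch for the builder API above (illustrative only: it assumes `client`
points at a live xla::Client; every call used is declared in this header):

    #include "tensorflow/compiler/xla/client/client.h"
    #include "tensorflow/compiler/xla/client/computation_builder.h"

    // Element-wise Add with an explicit broadcast dimension, as described in the
    // binary-op comments above, followed by Build().
    xla::StatusOr<xla::Computation> BuildBroadcastAdd(xla::Client* client) {
      xla::ComputationBuilder builder(client, "add_with_broadcast");
      auto matrix = builder.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
      auto row = builder.ConstantR1<float>({10.0f, 20.0f});
      // Map the rank-1 operand onto dimension 1 of the rank-2 operand, i.e. add
      // `row` to every row of `matrix`.
      builder.Add(matrix, row, /*broadcast_dimensions=*/{1});
      // Build() returns the last enqueued operation unless SetReturnValue was
      // called.
      return builder.Build();
    }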
diff --git a/tensorflow/compiler/xla/client/global_data.cc b/tensorflow/compiler/xla/client/global_data.cc
new file mode 100644
index 0000000000..be706f7d23
--- /dev/null
+++ b/tensorflow/compiler/xla/client/global_data.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+GlobalData::GlobalData(ServiceInterface* parent, GlobalDataHandle handle)
+ : handle_(handle), parent_(parent) {}
+
+GlobalData::~GlobalData() {
+ UnregisterRequest request;
+ *request.mutable_data() = handle_;
+ UnregisterResponse response;
+ VLOG(1) << "requesting to unregister " << handle_.ShortDebugString();
+ tensorflow::Status s = parent_->Unregister(&request, &response);
+ VLOG(1) << "done with request";
+
+ if (!s.ok()) {
+ LOG(WARNING) << "failed to unregister " << handle_.ShortDebugString()
+ << "; continuing anyway...";
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/global_data.h b/tensorflow/compiler/xla/client/global_data.h
new file mode 100644
index 0000000000..eb11d91034
--- /dev/null
+++ b/tensorflow/compiler/xla/client/global_data.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_
+
+#include "tensorflow/compiler/xla/service_interface.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Wraps a GlobalDataHandle with a lifetime.
+class GlobalData {
+ public:
+ // Gives ownership of the global data handle to this object.
+ GlobalData(ServiceInterface* parent, GlobalDataHandle handle);
+
+ // Unregisters the wrapped handle.
+ ~GlobalData();
+
+ const GlobalDataHandle& handle() const { return handle_; }
+
+ private:
+ GlobalDataHandle handle_; // Handle being wrapped.
+ ServiceInterface* parent_; // Service used to unregister handle_.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GlobalData);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_GLOBAL_DATA_H_
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
new file mode 100644
index 0000000000..e185beaedd
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -0,0 +1,60 @@
+# Common computation builders for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow/compiler/xla/client:friends"])
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "arithmetic",
+ srcs = ["arithmetic.cc"],
+ hdrs = ["arithmetic.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ ],
+)
+
+cc_library(
+ name = "testing",
+ srcs = ["testing.cc"],
+ hdrs = ["testing.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc
new file mode 100644
index 0000000000..31efd8ee64
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+Computation CreateScalarAddComputation(PrimitiveType type,
+ ComputationBuilder* builder) {
+ const Shape scalar = ShapeUtil::MakeShape(type, {});
+ auto b = builder->CreateSubBuilder("add_" + PrimitiveType_Name(type));
+ auto lhs = b->Parameter(0, scalar, "lhs");
+ auto rhs = b->Parameter(1, scalar, "rhs");
+ b->Add(lhs, rhs);
+ return b->BuildAndNoteError();
+}
+
+Computation CreateScalarGeComputation(PrimitiveType type,
+ ComputationBuilder* builder) {
+ const Shape scalar = ShapeUtil::MakeShape(type, {});
+ auto b = builder->CreateSubBuilder("ge_" + PrimitiveType_Name(type));
+ auto lhs = b->Parameter(0, scalar, "lhs");
+ auto rhs = b->Parameter(1, scalar, "rhs");
+ b->Ge(lhs, rhs);
+ return b->BuildAndNoteError();
+}
+
+Computation CreateScalarMaxComputation(PrimitiveType type,
+ ComputationBuilder* builder) {
+ const Shape scalar = ShapeUtil::MakeShape(type, {});
+ auto b = builder->CreateSubBuilder("max_" + PrimitiveType_Name(type));
+ auto lhs = b->Parameter(0, scalar, "lhs");
+ auto rhs = b->Parameter(1, scalar, "rhs");
+ b->Max(lhs, rhs);
+ return b->BuildAndNoteError();
+}
+
+Computation CreateScalarMinComputation(PrimitiveType type,
+ ComputationBuilder* builder) {
+ const Shape scalar = ShapeUtil::MakeShape(type, {});
+ auto b = builder->CreateSubBuilder("min_" + PrimitiveType_Name(type));
+ auto lhs = b->Parameter(0, scalar, "lhs");
+ auto rhs = b->Parameter(1, scalar, "rhs");
+ b->Min(lhs, rhs);
+ return b->BuildAndNoteError();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h
new file mode 100644
index 0000000000..57fe7d7462
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/arithmetic.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Creates a scalar add computation and returns it.
+Computation CreateScalarAddComputation(PrimitiveType type,
+ ComputationBuilder* builder);
+
+// Creates a scalar ge computation and returns it.
+Computation CreateScalarGeComputation(PrimitiveType type,
+ ComputationBuilder* builder);
+
+// Creates a scalar max computation and returns it.
+Computation CreateScalarMaxComputation(PrimitiveType type,
+ ComputationBuilder* builder);
+
+// Creates a scalar min computation and returns it.
+Computation CreateScalarMinComputation(PrimitiveType type,
+ ComputationBuilder* builder);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_ARITHMETIC_H_
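The helpers above pair naturally with ComputationBuilder::Reduce and
ReduceWindow. A sketch (illustrative only: `builder` must wrap a live client):

    #include "tensorflow/compiler/xla/client/computation_builder.h"
    #include "tensorflow/compiler/xla/client/lib/arithmetic.h"

    // Sums a rank-1 f32 array down to a scalar by reducing with the scalar add
    // computation.
    void ReduceSumSketch(xla::ComputationBuilder* builder) {
      xla::Computation add = xla::CreateScalarAddComputation(xla::F32, builder);
      auto input = builder->ConstantR1<float>({1.0f, 2.0f, 3.0f});
      builder->Reduce(input, /*init_value=*/builder->ConstantR0<float>(0.0f),
                      add, /*dimensions_to_reduce=*/{0});
    }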
diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc
new file mode 100644
index 0000000000..004f3815d2
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/testing.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/testing.h"
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
+ Client* client) {
+ ComputationBuilder b(
+ client,
+ tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape)));
+ // TODO(b/26811613): Replace this when RNG is supported on all backends.
+ b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())),
+ AsInt64Slice(shape.dimensions()));
+ Computation computation = b.Build().ConsumeValueOrDie();
+ return client->Execute(computation, /*arguments=*/{}, &shape)
+ .ConsumeValueOrDie();
+}
+
+std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
+ const Computation& computation, Client* client) {
+ auto program_shape =
+ client->GetComputationShape(computation).ConsumeValueOrDie();
+
+ // For every (unbound) parameter that the computation wants, we manufacture
+ // some arbitrary data so that we can invoke the computation.
+ std::vector<std::unique_ptr<GlobalData>> fake_arguments;
+ for (const Shape& parameter : program_shape->parameters()) {
+ fake_arguments.push_back(MakeFakeDataOrDie(parameter, client));
+ }
+
+ return fake_arguments;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/testing.h b/tensorflow/compiler/xla/client/lib/testing.h
new file mode 100644
index 0000000000..7e640d1307
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/testing.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// Generates fake data of the given shape on the device or dies. The fake data
+// is created by performing a computation on the device rather than transferring
+// data from the host to the device.
+std::unique_ptr<GlobalData> MakeFakeDataOrDie(const Shape& shape,
+ Client* client);
+
+// Returns a vector of GlobalData handles of fake data (created using
+// MakeFakeDataOrDie) that are correctly shaped arguments for the given
+// computation.
+std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
+ const Computation& computation, Client* client);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_TESTING_H_
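A sketch of how these helpers are meant to be used (illustrative only:
`computation` is assumed to have been built elsewhere and `client` to be a live
xla::Client*):

    #include <memory>
    #include <vector>

    #include "tensorflow/compiler/xla/client/lib/testing.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    void FakeDataSketch(const xla::Computation& computation, xla::Client* client) {
      // Fake device-side data for a single f32[2,3] array.
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
      std::unique_ptr<xla::GlobalData> data =
          xla::MakeFakeDataOrDie(shape, client);
      // One fake argument per parameter of the computation; the handles can then
      // be passed to Client execution calls that take GlobalData arguments.
      std::vector<std::unique_ptr<xla::GlobalData>> args =
          xla::MakeFakeArgumentsOrDie(computation, client);
      (void)data;
      (void)args;
    }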
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
new file mode 100644
index 0000000000..148c033eaa
--- /dev/null
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -0,0 +1,371 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+
+#include <utility>
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_platform(
+ perftools::gputools::Platform* platform) {
+ platform_ = platform;
+ return *this;
+}
+
+perftools::gputools::Platform* ExecutableBuildOptions::platform() const {
+ return platform_;
+}
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal(
+ int device_ordinal) {
+ device_ordinal_ = device_ordinal;
+ return *this;
+}
+
+int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; }
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout(
+ const Shape& shape_with_layout) {
+ result_layout_set_ = true;
+ result_layout_ = shape_with_layout;
+ return *this;
+}
+
+const Shape* ExecutableBuildOptions::result_layout() const {
+ return result_layout_set_ ? &result_layout_ : nullptr;
+}
+
+ExecutableBuildOptions& ExecutableBuildOptions::set_has_hybrid_result(
+ bool has_hybrid_result) {
+ has_hybrid_result_ = has_hybrid_result;
+ return *this;
+}
+
+bool ExecutableBuildOptions::has_hybrid_result() const {
+ return has_hybrid_result_;
+}
+
+namespace {
+
+// Convenience class which holds an acquired stream from the backend and
+// automatically releases it when destructed.
+class StreamManager {
+ public:
+ static StatusOr<std::unique_ptr<StreamManager>> AcquireStream(
+ Backend* backend, int device_ordinal) {
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ backend->stream_executor(device_ordinal == -1
+ ? backend->default_device_ordinal()
+ : device_ordinal));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
+ backend->AcquireStream(executor));
+ return WrapUnique(new StreamManager(backend, std::move(stream)));
+ }
+
+ ~StreamManager() { backend_->ReleaseStream(std::move(stream_)); }
+
+ se::Stream* stream() const { return stream_.get(); }
+
+ private:
+ StreamManager(Backend* backend, std::unique_ptr<se::Stream> stream)
+ : backend_(backend), stream_(std::move(stream)) {}
+
+ Backend* backend_;
+ std::unique_ptr<se::Stream> stream_;
+};
+
+} // namespace
+
+LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
+ Backend* backend, int device_ordinal,
+ const ExecutableBuildOptions& build_options)
+ : executable_(std::move(executable)),
+ backend_(backend),
+ build_device_ordinal_(device_ordinal),
+ build_options_(build_options) {}
+
+tensorflow::Status LocalExecutable::ValidateExecutionOptions(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options) {
+ const ComputationLayout& computation_layout =
+ executable_->module_config().entry_computation_layout();
+
+ // Check argument number, shapes, and layouts.
+ if (arguments.size() != computation_layout.parameter_count()) {
+ return InvalidArgument(
+ "invalid number of arguments for computation: expected %d, got %zu",
+ computation_layout.parameter_count(), arguments.size());
+ }
+ for (int i = 0; i < arguments.size(); ++i) {
+ if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ arguments[i]->shape())) {
+ return InvalidArgument(
+ "argument does not match shape or layout of computation parameter "
+ "%d: expected %s, got %s",
+ i,
+ ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape())
+ .c_str(),
+ ShapeUtil::HumanString(arguments[i]->shape()).c_str());
+ }
+ }
+
+ if (options.stream() != nullptr) {
+ if (!options.stream()->ok()) {
+ return InvalidArgument("stream is uninitialized or in an error state");
+ }
+
+ // Check stream matches service platform.
+ const se::Platform* stream_platform =
+ options.stream()->parent()->platform();
+ if (stream_platform != backend_->platform()) {
+ return InvalidArgument(
+ "stream is for platform %s, but service targets platform %s",
+ stream_platform->Name().c_str(),
+ backend_->platform()->Name().c_str());
+ }
+
+ // Cannot specify device_ordinal with a stream; the stream determines the
+ // device ordinal.
+ if (options.device_ordinal() != -1) {
+ return InvalidArgument(
+ "cannot set both device ordinal and stream options in "
+ "ExecutableRunOptions; the stream determines the device ordinal");
+ }
+ }
+
+ // Verify that the device the executable was built for is equivalent to the
+ // device it will run on.
+ int run_device_ordinal = options.device_ordinal() == -1
+ ? backend_->default_device_ordinal()
+ : options.device_ordinal();
+ TF_ASSIGN_OR_RETURN(
+ bool devices_equivalent,
+ backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_));
+ if (!devices_equivalent) {
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor,
+ backend_->stream_executor(run_device_ordinal));
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor,
+ backend_->stream_executor(build_device_ordinal_));
+ return InvalidArgument(
+ "executable is built for device %s of type \"%s\"; cannot run it on "
+ "device %s of type \"%s\"",
+ backend_->device_name(build_device_ordinal_).c_str(),
+ build_executor->GetDeviceDescription().name().c_str(),
+ backend_->device_name(run_device_ordinal).c_str(),
+ run_executor->GetDeviceDescription().name().c_str());
+ }
+
+ return tensorflow::Status::OK();
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options) {
+ TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options));
+
+ ExecutableRunOptions actual_options = options;
+ std::unique_ptr<StreamManager> acquired_stream;
+ if (options.stream() == nullptr) {
+ TF_ASSIGN_OR_RETURN(
+ acquired_stream,
+ StreamManager::AcquireStream(backend_, options.device_ordinal()));
+ actual_options.set_stream(acquired_stream->stream());
+ }
+ if (options.allocator() == nullptr) {
+ actual_options.set_allocator(backend_->memory_allocator());
+ }
+
+ if (executable_->dumping()) {
+ return ExecuteAndDump(&actual_options, arguments);
+ }
+ return executable_->ExecuteOnStream(&actual_options, arguments,
+ /*hlo_execution_profile=*/nullptr);
+}
+
+tensorflow::Status LocalExecutable::Run(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options, ShapedBuffer* result) {
+ const ComputationLayout& computation_layout =
+ executable_->module_config().entry_computation_layout();
+ TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options));
+
+ if (!computation_layout.result_layout().MatchesLayoutInShape(
+ result->shape())) {
+ return InvalidArgument(
+ "result buffer does not match shape or layout of computation result: "
+ "expected %s, got %s",
+ ShapeUtil::HumanString(computation_layout.result_layout().shape())
+ .c_str(),
+ ShapeUtil::HumanString(result->shape()).c_str());
+ }
+
+ ExecutableRunOptions actual_options = options;
+ std::unique_ptr<StreamManager> acquired_stream;
+ if (options.stream() == nullptr) {
+ TF_ASSIGN_OR_RETURN(
+ acquired_stream,
+ StreamManager::AcquireStream(backend_, options.device_ordinal()));
+ actual_options.set_stream(acquired_stream->stream());
+ }
+ if (options.allocator() == nullptr) {
+ actual_options.set_allocator(backend_->memory_allocator());
+ }
+
+ if (executable_->dumping()) {
+ return Unimplemented("dumping execution not supported on this path");
+ }
+ return executable_->ExecuteOnStream(&actual_options, arguments, result,
+ /*hlo_execution_profile=*/nullptr);
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::ExecuteAndDump(
+ const ExecutableRunOptions* run_options,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ executable_->session_module()->set_execution_platform(
+ backend_->platform()->Name());
+ TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<ShapedBuffer> result,
+ executable_->ExecuteOnStream(run_options, arguments,
+ /*hlo_execution_profile=*/nullptr));
+ TF_RETURN_IF_ERROR(RecordResult(result.get(), executable_->session_module()));
+ TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
+ return std::move(result);
+}
+
+tensorflow::Status LocalExecutable::RecordArguments(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ SessionModule* session_module) {
+ session_module->clear_arguments();
+ for (const ShapedBuffer* argument : arguments) {
+ TF_RETURN_IF_ERROR(
+ LiteralFromShapedBuffer(*argument, session_module->add_arguments()));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status LocalExecutable::RecordResult(
+ const ShapedBuffer* result, SessionModule* session_module) {
+ session_module->clear_result();
+ return LiteralFromShapedBuffer(*result, session_module->mutable_result());
+}
+
+tensorflow::Status LocalExecutable::LiteralFromShapedBuffer(
+ const ShapedBuffer& shaped_buffer, Literal* literal) {
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ backend_->stream_executor(shaped_buffer.device_ordinal()));
+ return backend_->transfer_manager()->TransferLiteralFromDevice(
+ executor, shaped_buffer.buffer({}), shaped_buffer.shape(),
+ shaped_buffer.shape(), literal);
+}
+
+StatusOr<std::unique_ptr<GlobalData>> LocalClient::AllocateBufferOnDevice(
+ const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) {
+ TF_ASSIGN_OR_RETURN(GlobalDataHandle handle,
+ local_service_->AllocateBufferOnDevice(
+ shape, device_ordinal, allocate_space_for_deep_copy));
+ return std::unique_ptr<GlobalData>(new GlobalData(local_service_, handle));
+}
+
+tensorflow::Status LocalClient::ResolveArguments(
+ const tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* argument_ptrs) {
+ return local_service_->ResolveArguments(arguments, device_ordinal,
+ argument_ptrs);
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalClient::ExecuteLocally(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options) {
+ return local_service_->ExecuteLocally(computation.handle(), arguments,
+ options);
+}
+
+tensorflow::Status LocalClient::ExecuteLocally(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result) {
+ return local_service_->ExecuteLocally(computation.handle(), arguments,
+ options, result);
+}
+
+StatusOr<std::unique_ptr<AotCompilationResult>> LocalClient::CompileAheadOfTime(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const Shape& result_layout, const AotCompilationOptions& options) {
+ return local_service_->CompileAheadOfTime(
+ computation.handle(), argument_layouts, result_layout, options);
+}
+
+int64 LocalClient::PointerSizeForTriple(tensorflow::StringPiece target_triple) {
+ llvm::Triple triple(
+ llvm::Triple::normalize(llvm_ir::AsStringRef(target_triple)));
+ if (triple.isArch64Bit()) {
+ return 8;
+ } else if (triple.isArch32Bit()) {
+ return 4;
+ } else {
+ CHECK(triple.isArch16Bit());
+ return 2;
+ }
+}
+
+se::Platform* LocalClient::platform() const {
+ return local_service_->backend().platform();
+}
+
+int LocalClient::device_count() const {
+ return local_service_->backend().device_count();
+}
+
+bool LocalClient::device_ordinal_supported(int device_ordinal) const {
+ return local_service_->backend().device_ordinal_supported(device_ordinal);
+}
+
+int LocalClient::default_device_ordinal() const {
+ return local_service_->backend().default_device_ordinal();
+}
+
+StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& options) {
+ int device_ordinal = options.device_ordinal() == -1
+ ? default_device_ordinal()
+ : options.device_ordinal();
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ local_service_->CompileExecutable(computation.handle(), argument_layouts,
+ options.result_layout(), device_ordinal,
+ options.has_hybrid_result()));
+ return WrapUnique(new LocalExecutable(std::move(executable),
+ local_service_->mutable_backend(),
+ device_ordinal, options));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
new file mode 100644
index 0000000000..1d6243a3b6
--- /dev/null
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -0,0 +1,263 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Class containing options for building a LocalExecutable with
+// LocalClient::Compile.
+class ExecutableBuildOptions {
+ public:
+ // If set, this is the platform to build the computation for. This must match
+ // the underlying platform of the service. A value of nullptr indicates the
+ // option has not been set.
+ //
+ // TODO(b/28616830): Support multiple platforms.
+ ExecutableBuildOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // If set, this is the device to build the computation for. Valid
+ // device_ordinal values are: 0 to # of devices - 1. These values are
+ // identical to the device ordinal values used by StreamExecutor. The built
+ // executable will be executable on any device equivalent to the specified
+ // device as determined by Backend::devices_equivalent(). A value of -1
+ // indicates this option has not been set.
+ ExecutableBuildOptions& set_device_ordinal(int device_ordinal);
+ int device_ordinal() const;
+
+ // If set, this specifies the layout of the result of the computation. If not
+ // set, the service will choose the layout of the result. A Shape is used to
+ // store the layout to accommodate tuple result shapes. A value of nullptr
+ // indicates the option has not been set.
+ ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout);
+ const Shape* result_layout() const;
+
+ // If set, the executable will be built to output a hybrid
+ // ShapedBuffer with top-level tuple pointers in host memory and
+ // result buffers in device memory.
+ ExecutableBuildOptions& set_has_hybrid_result(bool has_hybrid_result);
+ bool has_hybrid_result() const;
+
+ private:
+ perftools::gputools::Platform* platform_ = nullptr;
+ int device_ordinal_ = -1;
+ Shape result_layout_;
+ bool result_layout_set_ = false;
+ bool has_hybrid_result_ = true;
+};
+
+class LocalExecutable {
+ public:
+ // Run the compiled computation with the given arguments and options and
+ // return the result.
+ StatusOr<std::unique_ptr<ShapedBuffer>> Run(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options);
+
+ // Overload which places the computation result in the given preallocated
+ // buffer.
+ tensorflow::Status Run(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options, ShapedBuffer* result);
+
+ // Return the layout (contained in a shape) of the result produced by the
+ // computation.
+ const Shape& result_layout() const {
+ return executable_->module_config()
+ .entry_computation_layout()
+ .result_layout()
+ .shape();
+ }
+
+ // Return the options used to build the executable.
+ const ExecutableBuildOptions& build_options() const { return build_options_; }
+
+ // Return the built executable.
+ Executable* executable() const { return executable_.get(); }
+
+ private:
+ // Only a local client can construct these objects.
+ friend class LocalClient;
+
+ // Constructor invoked by LocalClient.
+ LocalExecutable(std::unique_ptr<Executable> executable, Backend* backend,
+ int device_ordinal,
+ const ExecutableBuildOptions& build_options);
+
+ // Validates that the given arguments and options satisfy various constraints
+ // of the computation.
+ tensorflow::Status ValidateExecutionOptions(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const ExecutableRunOptions& options);
+
+ // Records the computation in a SessionModule proto with the arguments used to
+ // invoke it, and the result. Enabled by flag: --xla_dump_executions_to.
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteAndDump(
+ const ExecutableRunOptions* run_options,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+
+ // Records the arguments used to invoke the computation in a SessionModule
+ // proto.
+ tensorflow::Status RecordArguments(
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ SessionModule* session_module);
+
+ // Records the result of the computation in a SessionModule proto.
+ tensorflow::Status RecordResult(const ShapedBuffer* result,
+ SessionModule* session_module);
+
+ // Copies the contents of a ShapedBuffer into a Literal proto.
+ tensorflow::Status LiteralFromShapedBuffer(const ShapedBuffer& shaped_buffer,
+ Literal* literal);
+
+ // Compiled computation.
+ std::unique_ptr<Executable> executable_;
+
+ // Execution backend.
+ Backend* backend_;
+
+ // The ordinal of the device which this executable was compiled for. The
+ // executable can run on all equivalent devices (as determined by
+ // Backend::devices_equivalent).
+ int build_device_ordinal_;
+
+ // Options used to build the executable.
+ const ExecutableBuildOptions& build_options_;
+};
+
+// An XLA service client object for use when the client and service run in
+// the same process.
+class LocalClient : public Client {
+ public:
+ explicit LocalClient(LocalService* service)
+ : Client(service), local_service_(service) {}
+
+ LocalClient(const LocalClient&) = delete;
+ void operator=(const LocalClient&) = delete;
+
+ // For an array of arguments held on the local service, validate
+ // that each is placed on the specified device_ordinal, and return
+ // the DeviceMemoryBase corresponding to each argument.
+ tensorflow::Status ResolveArguments(
+ const tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* argument_ptrs);
+
+ // Return a handle to a buffer large enough to hold shape, allocated
+ // on device_ordinal on the local service. If
+ // allocate_space_for_deep_copy, the buffer is large enough to hold
+ // all sub-buffers of a tuple shape, otherwise it is only as large
+ // as the top-level tuple pointer array.
+ StatusOr<std::unique_ptr<GlobalData>> AllocateBufferOnDevice(
+ const Shape& shape, int device_ordinal,
+ bool allocate_space_for_deep_copy);
+
+ // Executes the given computation with the given arguments and
+ // options. Arguments and result are "zero-copy", and are passed as pointers
+ // to device memory. See LocalExecuteOptions class comments for description of
+ // available options. The returned ShapedBuffer includes pointer(s) to device
+ // memory (DeviceMemoryBase) which are the caller's responsibility to
+ // deallocate. The layout of the result is chosen by the XLA service and
+ // should not be relied upon to be a specific value. If a specific result
+ // layout is needed, then the layout should be set in options.
+ //
+ // The arrays of arguments with different shapes or layouts are assumed not to
+ // alias.
+ //
+ // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use
+ // Compile and run the returned LocalExecutable.
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteLocally(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options);
+
+ // Overload of ExecuteLocally which writes the result into the given
+ // ShapedBuffer "result". Result is const because the ShapedBuffer data
+ // structure itself is not modified, only the buffers in device memory to
+ // which it refers.
+ //
+ // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use
+ // Compile and run the returned LocalExecutable.
+ tensorflow::Status ExecuteLocally(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result);
+
+ // Build and return a LocalExecutable object. The executable is compiled using
+ // the given argument layouts and options.
+ StatusOr<std::unique_ptr<LocalExecutable>> Compile(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const ExecutableBuildOptions& options);
+
+ // Compiles the computation for ahead-of-time execution. This is intended for
+ // use in static compilation. The |argument_layouts| parameter is used to
+ // inform the compiler of the expected layout for arguments while
+ // |result_layout| is used to signal the layout of the result. The |options|
+ // parameter is used to request which target the compiler should emit code
+ // for.
+ //
+ // TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
+ // own library.
+ StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
+ const Computation& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const Shape& result_layout, const AotCompilationOptions& options);
+
+ // Returns the size of a pointer in bytes for a given triple.
+ static int64 PointerSizeForTriple(tensorflow::StringPiece triple);
+
+ // Returns the platform that the underlying service targets.
+ perftools::gputools::Platform* platform() const;
+
+ // Returns the number of devices on the system of the service platform
+ // type. Not all devices may be supported by the service (see
+ // device_ordinal_supported method).
+ int device_count() const;
+
+ // Returns the default device ordinal that the service will run computations
+ // on if no device ordinal is specified in execute options.
+ int default_device_ordinal() const;
+
+ // Returns whether the device with the given ordinal can be used by the
+ // service to execute computations. Not all devices of a particular platform
+ // may be usable by the service (eg, a GPU with insufficient CUDA compute
+ // capability).
+ bool device_ordinal_supported(int device_ordinal) const;
+
+ private:
+ LocalService* local_service_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LOCAL_CLIENT_H_
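The Compile entry point above is the path the TODOs point to in place of the ExecuteLocally overloads. Below is a minimal illustrative sketch of driving it; it is not part of this diff, and it assumes a live LocalClient* and Computation already exist, that ExecutableBuildOptions is default-constructible, and that the returned StatusOr exposes the usual ok()/ConsumeValueOrDie() accessors:

    // Illustrative sketch only; not part of this diff.
    #include <memory>
    #include "tensorflow/compiler/xla/client/local_client.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    void CompileExample(xla::LocalClient* client,
                        const xla::Computation& computation) {
      // One f32[2,2] argument; device and layout choices are left to defaults.
      xla::Shape arg_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
      xla::ExecutableBuildOptions build_options;  // assumed default-constructible
      auto compiled = client->Compile(computation, {&arg_shape}, build_options);
      if (compiled.ok()) {
        std::unique_ptr<xla::LocalExecutable> executable =
            compiled.ConsumeValueOrDie();  // StatusOr accessor, assumed available
        // The executable can now be run repeatedly without recompiling.
      }
    }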
diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc
new file mode 100644
index 0000000000..281fa10408
--- /dev/null
+++ b/tensorflow/compiler/xla/client/padding.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/padding.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+std::vector<std::pair<int64, int64>> MakePadding(
+ tensorflow::gtl::ArraySlice<int64> input_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
+ CHECK_EQ(input_dimensions.size(), window_dimensions.size());
+ CHECK_EQ(input_dimensions.size(), window_strides.size());
+ std::vector<std::pair<int64, int64>> low_high_padding;
+ switch (padding) {
+ case Padding::kValid:
+ low_high_padding.resize(window_dimensions.size(), {0, 0});
+ return low_high_padding;
+
+ case Padding::kSame:
+ for (int64 i = 0; i < input_dimensions.size(); ++i) {
+ int64 input_dimension = input_dimensions[i];
+ int64 window_dimension = window_dimensions[i];
+ int64 window_stride = window_strides[i];
+ // We follow the same convention as in Tensorflow, such that
+ // output dimension := ceil(input_dimension / window_stride).
+ // See tensorflow/tensorflow/python/ops/nn.py
+ // for the reference. See also tensorflow/core/kernels/ops_util.cc
+ // for the part where we avoid negative padding using max(0, x).
+ //
+ // For an odd sized window dimension 2N+1 with stride 1, the middle
+ // element is always inside the base area, so we can see it as N + 1 +
+ // N elements. In the example below, we have a kernel of size
+ // 2*3+1=7 so that the center element is 4 with 123 to the
+ // left and 567 to the right.
+ //
+ // base area: ------------------------
+ // kernel at left: 1234567
+ // kernel at right: 1234567
+ //
+ // We can see visually here that we need to pad the base area
+ // by 3 on each side:
+ //
+ // padded base area: 000------------------------000
+ //
+ // For an even number 2N, there are two options:
+ //
+ // *** Option A
+ //
+ // We view 2N as (N - 1) + 1 + N, so for N=3 we have 12 to the
+ // left, 3 is the center and 456 is to the right, like this:
+ //
+ // base area: ------------------------
+ // kernel at left: 123456
+ // kernel at right: 123456
+ // padded base area: 00------------------------000
+ //
+ // Note how we pad by one more to the right than to the left.
+ //
+ // *** Option B
+ //
+ // We view 2N as N + 1 + (N - 1), so for N=3 we have 123 to
+ // the left, 4 is the center and 56 is to the right, like
+ // this:
+ //
+ // base area: ------------------------
+ // kernel at left: 123456
+ // kernel at right: 123456
+ // padded base area: 000------------------------00
+ //
+ // The choice here is arbitrary. We choose option A as this is
+ // what DistBelief and Tensorflow do.
+ //
+ // When the stride is greater than 1, the output size is smaller than
+ // the input base size. The base area is padded such that the last
+ // window fully fits in the padded base area, and the padding amount is
+ // evenly divided between the left and the right (or 1 more on the right
+ // if odd size padding is required). The example below shows the
+ // required padding when the base size is 10, the kernel size is 5, and
+ // the stride is 3. In this example, the output size is 4.
+ //
+ // base area: ----------
+ // 1'st kernel: 12345
+ // 2'nd kernel: 12345
+ // 3'rd kernel: 12345
+ // 4'th kernel: 12345
+ // padded base area: 00----------00
+ int64 output_dimension =
+ tensorflow::MathUtil::CeilOfRatio(input_dimension, window_stride);
+ int64 padding_size =
+ std::max<int64>((output_dimension - 1) * window_stride +
+ window_dimension - input_dimension,
+ 0);
+ low_high_padding.emplace_back(
+ tensorflow::MathUtil::FloorOfRatio(padding_size, 2ll),
+ tensorflow::MathUtil::CeilOfRatio(padding_size, 2ll));
+ }
+ break;
+ }
+
+ return low_high_padding;
+}
+
+} // namespace xla
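To make the stride-3 arithmetic above concrete (an editorial worked example, not part of this diff): with input_dimension = 10, window_dimension = 5 and window_stride = 3, output_dimension = ceil(10/3) = 4, so padding_size = max((4 - 1) * 3 + 5 - 10, 0) = 4, which is split into floor(4/2) = 2 on the low side and ceil(4/2) = 2 on the high side — exactly the "00----------00" picture in the comment.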
diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h
new file mode 100644
index 0000000000..dce2d87e8d
--- /dev/null
+++ b/tensorflow/compiler/xla/client/padding.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_
+
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+
+// Describes the padding applied for a windowed operation like
+// convolution, where a window is placed inside a base area.
+enum class Padding {
+ // Make the output have the same dimensions as the base area. For
+ // example, for a 3x3 base area and a 2x2 window, the output will be
+ // 3x3, so that requires padding the 3x3 base area to 4x4.
+ kSame,
+
+ // Use no padding. For example, for a 4x4 base area and a 2x2
+ // window, the output will be 3x3.
+ kValid,
+};
+
+// Returns the padding needed for the base area, given the base area dimensions,
+// window dimensions, strides, and the type of padding.
+//
+// If v is the returned vector, then for each dimension number i,
+// v[i].first is the padding to the left (i.e. in the direction of
+// lower indices) and v[i].second is the padding to the right (i.e. in
+// the direction of higher indices).
+//
+// Precondition: The number of dimensions (i.e., rank) in input_dimensions,
+// window_dimensions, and strides must match; this is also the number of
+// elements in the returned vector.
+std::vector<std::pair<int64, int64>> MakePadding(
+ tensorflow::gtl::ArraySlice<int64> input_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> strides, Padding padding);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_PADDING_H_
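An illustrative call (not part of this diff) showing how the returned low/high pairs line up with the dimensions; the expected values match the kSame unit tests added below:

    // Illustrative sketch only; not part of this diff.
    #include <utility>
    #include <vector>
    #include "tensorflow/compiler/xla/client/padding.h"

    void PaddingExample() {
      // A 2-D kSame window over a 10x10 base area with unit strides.
      std::vector<std::pair<xla::int64, xla::int64>> padding = xla::MakePadding(
          /*input_dimensions=*/{10, 10}, /*window_dimensions=*/{7, 6},
          /*strides=*/{1, 1}, xla::Padding::kSame);
      // padding[0] == {3, 3} and padding[1] == {2, 3}: .first is the low
      // (lower-index) padding and .second the high padding for that dimension.
    }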
diff --git a/tensorflow/compiler/xla/client/padding_test.cc b/tensorflow/compiler/xla/client/padding_test.cc
new file mode 100644
index 0000000000..deda5ce708
--- /dev/null
+++ b/tensorflow/compiler/xla/client/padding_test.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/padding.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+// Tests MakePadding utility function for various cases.
+class PaddingTest : public ::testing::Test {
+ protected:
+ // A convenience function to test padding for a single dimension.
+ std::pair<int64, int64> ComputePadding(int64 input_dimension,
+ int64 window_dimension,
+ int64 window_stride, Padding padding) {
+ return MakePadding({input_dimension}, {window_dimension}, {window_stride},
+ padding)[0];
+ }
+};
+
+TEST_F(PaddingTest, ValidPaddingWithStrideOne) {
+ const auto padding = ComputePadding(10, 5, 1, Padding::kValid);
+ EXPECT_EQ(padding.first, 0);
+ EXPECT_EQ(padding.second, 0);
+}
+
+TEST_F(PaddingTest, ValidPaddingWithStrideThree) {
+ const auto padding = ComputePadding(10, 5, 3, Padding::kValid);
+ EXPECT_EQ(padding.first, 0);
+ EXPECT_EQ(padding.second, 0);
+}
+
+TEST_F(PaddingTest, SamePaddingWithOddWindow) {
+ const auto padding = ComputePadding(10, 7, 1, Padding::kSame);
+ EXPECT_EQ(padding.first, 3);
+ EXPECT_EQ(padding.second, 3);
+}
+
+TEST_F(PaddingTest, SamePaddingWithEvenWindow) {
+ const auto padding = ComputePadding(10, 6, 1, Padding::kSame);
+ EXPECT_EQ(padding.first, 2);
+ EXPECT_EQ(padding.second, 3);
+}
+
+TEST_F(PaddingTest, SamePaddingWithOddWindowWithStride) {
+ const auto padding = ComputePadding(10, 7, 3, Padding::kSame);
+ EXPECT_EQ(padding.first, 3);
+ EXPECT_EQ(padding.second, 3);
+}
+
+TEST_F(PaddingTest, SamePaddingWithEvenWindowWithStride) {
+ const auto padding = ComputePadding(10, 6, 4, Padding::kSame);
+ EXPECT_EQ(padding.first, 2);
+ EXPECT_EQ(padding.second, 2);
+}
+
+TEST_F(PaddingTest, SamePaddingForWindowSizeOne) {
+ const auto padding = ComputePadding(10, 1, 1, Padding::kSame);
+ EXPECT_EQ(padding.first, 0);
+ EXPECT_EQ(padding.second, 0);
+}
+
+TEST_F(PaddingTest, SamePaddingForWindowLargerThanInput) {
+ const auto padding = ComputePadding(10, 20, 1, Padding::kSame);
+ EXPECT_EQ(padding.first, 9);
+ EXPECT_EQ(padding.second, 10);
+}
+
+// This used to trigger a case with negative padding.
+TEST_F(PaddingTest, NonNegativePadding) {
+ const auto padding = ComputePadding(4, 1, 2, Padding::kSame);
+ EXPECT_EQ(padding.first, 0);
+ EXPECT_EQ(padding.second, 0);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/device_util.h b/tensorflow/compiler/xla/device_util.h
new file mode 100644
index 0000000000..23a622b1ad
--- /dev/null
+++ b/tensorflow/compiler/xla/device_util.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities common between the client and server for working with
+// StreamExecutor devices.
+
+#ifndef TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Returns a string that represents the device in terms of platform and ordinal;
+// e.g. the first CUDA device will be "cuda:0".
+string DeviceIdentifier(perftools::gputools::StreamExecutor* stream_exec) {
+ return tensorflow::strings::StrCat(stream_exec->platform()->Name(), ":",
+ stream_exec->device_ordinal());
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_DEVICE_UTIL_H_
diff --git a/tensorflow/compiler/xla/differential_set.h b/tensorflow/compiler/xla/differential_set.h
new file mode 100644
index 0000000000..9eae24ce30
--- /dev/null
+++ b/tensorflow/compiler/xla/differential_set.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_
+#define TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_
+
+#include <set>
+
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// In the base case, the differential set is just a set.
+// However, you can also point a differential set at another differential set to
+// use as a "parent". This makes a chain of sets, in which each node in the
+// chain adds some number of elements to the "Contains" property.
+//
+// E.g. if the base set holds {1, 2}, you can create a derived set that holds
+// {3}, and the derived set will tell you it contains {1, 2, 3} whereas the base
+// will continue to tell you it holds only {1, 2}.
+template <typename T>
+class DifferentialSet {
+ public:
+ // Constructs a differential set capable of holding values. It may have an
+ // ancestor link, which would link it into a chain of sets.
+ explicit DifferentialSet(const DifferentialSet* parent = nullptr)
+ : parent_(parent) {}
+
+ // Adds a value to be held directly by this set.
+ void Add(T value) { held_.insert(value); }
+
+ // Returns whether this set holds the given value, or any ancestor in the
+ // chain of sets.
+ bool Contains(T value) const {
+ return held_.find(value) != held_.end() ||
+ (parent_ != nullptr && parent_->Contains(value));
+ }
+
+ private:
+ // Values held directly by this node in the chain of sets.
+ std::set<T> held_;
+
+ // Parent node in the chain of sets.
+ const DifferentialSet* parent_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DifferentialSet);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_DIFFERENTIAL_SET_H_
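An illustrative use (not part of this diff) mirroring the {1, 2} / {3} example in the class comment:

    // Illustrative sketch only; not part of this diff.
    #include "tensorflow/compiler/xla/differential_set.h"

    void DifferentialSetExample() {
      xla::DifferentialSet<int> base;
      base.Add(1);
      base.Add(2);
      xla::DifferentialSet<int> derived(&base);  // chains onto `base`
      derived.Add(3);
      // derived.Contains(1), derived.Contains(2) and derived.Contains(3) are all
      // true, while base.Contains(3) is still false.
    }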
diff --git a/tensorflow/compiler/xla/differential_set_test.cc b/tensorflow/compiler/xla/differential_set_test.cc
new file mode 100644
index 0000000000..dacbbcc1ad
--- /dev/null
+++ b/tensorflow/compiler/xla/differential_set_test.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/differential_set.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingHeld) {
+ DifferentialSet<int> set;
+ set.Add(1);
+ set.Add(2);
+ EXPECT_FALSE(set.Contains(3));
+ EXPECT_TRUE(set.Contains(1));
+ EXPECT_TRUE(set.Contains(2));
+ EXPECT_FALSE(set.Contains(0));
+}
+
+TEST(DifferentialSetTest, TellsWhetherSetContainsSomethingParentHolds) {
+ DifferentialSet<int> parent;
+ parent.Add(1);
+ DifferentialSet<int> child(&parent);
+ child.Add(2);
+
+ // Test properties of the child.
+ EXPECT_FALSE(child.Contains(3));
+ EXPECT_TRUE(child.Contains(1));
+ EXPECT_TRUE(child.Contains(2));
+ EXPECT_FALSE(child.Contains(0));
+
+ // Test properties of the parent.
+ EXPECT_TRUE(parent.Contains(1));
+ EXPECT_FALSE(parent.Contains(2));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
new file mode 100644
index 0000000000..1c54fec97c
--- /dev/null
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -0,0 +1,70 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+
+namespace xla {
+
+ExecutableRunOptions& ExecutableRunOptions::set_device_ordinal(
+ int device_ordinal) {
+ device_ordinal_ = device_ordinal;
+ return *this;
+}
+
+int ExecutableRunOptions::device_ordinal() const { return device_ordinal_; }
+
+ExecutableRunOptions& ExecutableRunOptions::set_allocator(
+ DeviceMemoryAllocator* allocator) {
+ allocator_ = allocator;
+ return *this;
+}
+
+DeviceMemoryAllocator* ExecutableRunOptions::allocator() const {
+ return allocator_;
+}
+
+ExecutableRunOptions& ExecutableRunOptions::set_stream(
+ perftools::gputools::Stream* stream) {
+ stream_ = stream;
+ return *this;
+}
+
+perftools::gputools::Stream* ExecutableRunOptions::stream() const {
+ return stream_;
+}
+
+ExecutableRunOptions& ExecutableRunOptions::set_inter_op_thread_pool(
+ tensorflow::thread::ThreadPool* inter_op_thread_pool) {
+ inter_op_thread_pool_ = inter_op_thread_pool;
+ return *this;
+}
+
+tensorflow::thread::ThreadPool* ExecutableRunOptions::inter_op_thread_pool()
+ const {
+ return inter_op_thread_pool_;
+}
+
+ExecutableRunOptions& ExecutableRunOptions::set_intra_op_thread_pool(
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool) {
+ intra_op_thread_pool_ = intra_op_thread_pool;
+ return *this;
+}
+
+const Eigen::ThreadPoolDevice* ExecutableRunOptions::intra_op_thread_pool()
+ const {
+ return intra_op_thread_pool_;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
new file mode 100644
index 0000000000..212fce9eab
--- /dev/null
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
+
+// Intentionally forward declared so that ExecutableRunOptions can be linked
+// into an XLA-compiled binary without having to link all of the pointed-to
+// objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't
+// need to be linked).
+namespace perftools {
+namespace gputools {
+class Stream;
+class Platform;
+}
+}
+
+namespace tensorflow {
+namespace thread {
+class ThreadPool;
+}
+}
+
+namespace Eigen {
+struct ThreadPoolDevice;
+}
+
+namespace xla {
+
+class DeviceMemoryAllocator;
+
+// Class containing options for running a LocalExecutable.
+class ExecutableRunOptions {
+ public:
+ // Specifies the allocator to use during execution.
+ ExecutableRunOptions& set_allocator(DeviceMemoryAllocator* allocator);
+ DeviceMemoryAllocator* allocator() const;
+
+ // If set, this is the device to run the computation on. Valid device_ordinal
+ // values are: 0 to # of devices - 1. These values are identical to the device
+ // ordinal values used by StreamExecutor. The device must be of the same type
+ // as the executable was compiled for. A value of -1 indicates this option has
+ // not been set.
+ ExecutableRunOptions& set_device_ordinal(int device_ordinal);
+ int device_ordinal() const;
+
+ // If set, this is the stream to run the computation on. The platform of the
+ // stream must match the platform the executable was built for. A value of
+ // nullptr indicates the option has not been set.
+ ExecutableRunOptions& set_stream(perftools::gputools::Stream* stream);
+ perftools::gputools::Stream* stream() const;
+
+ // Sets the thread pool on which to run parallel CPU backend
+ // computations. Does not take ownership.
+ ExecutableRunOptions& set_inter_op_thread_pool(
+ tensorflow::thread::ThreadPool* inter_op_thread_pool);
+ tensorflow::thread::ThreadPool* inter_op_thread_pool() const;
+
+ // Sets the thread pool device on which to run Eigen subcomputations.
+ // Does not take ownership.
+ ExecutableRunOptions& set_intra_op_thread_pool(
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool);
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool() const;
+
+ private:
+ DeviceMemoryAllocator* allocator_ = nullptr;
+ int device_ordinal_ = -1;
+ perftools::gputools::Stream* stream_ = nullptr;
+ tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr;
+ const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
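An illustrative sketch (not part of this diff) of the builder-style setters above; `allocator` and `stream` are assumed to point at a valid DeviceMemoryAllocator and Stream for the executable's platform:

    // Illustrative sketch only; not part of this diff.
    #include "tensorflow/compiler/xla/executable_run_options.h"

    xla::ExecutableRunOptions MakeRunOptions(
        xla::DeviceMemoryAllocator* allocator,
        perftools::gputools::Stream* stream) {
      xla::ExecutableRunOptions run_options;
      // Each setter returns *this, so the calls can be chained.
      run_options.set_device_ordinal(0).set_allocator(allocator).set_stream(
          stream);
      // Fields that are never set keep their defaults: device_ordinal() == -1
      // means "use the service default" and null pointers mean "not provided".
      return run_options;
    }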
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
new file mode 100644
index 0000000000..901fcd89ea
--- /dev/null
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -0,0 +1,126 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/index_util.h"
+
+#include <algorithm>
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
+ const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ CHECK_EQ(shape.dimensions_size(), multi_index.size());
+ // Padding and nested layouts not supported yet.
+ CHECK_EQ(0, shape.layout().padded_dimensions_size());
+
+ for (int i = 0; i < multi_index.size(); ++i) {
+ CHECK_GE(multi_index[i], 0);
+ CHECK_LT(multi_index[i], shape.dimensions(i))
+ << "indexing beyond extent in dimension " << i << ":"
+ << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",")
+ << "\n\tshape: " << ShapeUtil::HumanString(shape);
+ }
+
+ // Let the array be sized like so for dimensions i from 0 to n-1:
+ //
+ // [D{n-1} x D{n-2} x .. x D{0}]
+ //
+ // Let the order of the dimensions in the minor_to_major field in
+ // Layout be:
+ //
+ // L(0), L(1), ... , L(n-1)
+ //
+ // where L(0) is the most-minor dimension and L(n-1) the most-major. The
+ // multidimensional index:
+ //
+ // [I{0}, I{1}, ... , I{n-1}]
+ //
+ // then corresponds to the following linear index:
+ //
+ // linear_index =
+ // ((( ... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)}
+ //
+ // or equivalently:
+ //
+ // linear_index =
+ // I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) +
+ // I{L(n-2)} * (D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) +
+ // I{L(n-3)} * (D{L(n-4)} * .... D{L(0)}) +
+ // ... +
+ // I{L(2)} * (D{L(1)} * D{L(0)}) +
+ // I{L(1)} * D{L(0)} +
+ // I{L(0)}
+ //
+ // We compute the linear index value by accumulating the terms above from
+ // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0}} *
+ // D{L(1)} * ...
+
+ // Scale factor holding the growing product of D{L(i)} terms.
+ int64 scale = 1;
+ int64 linear_index = 0;
+ for (auto dimension : shape.layout().minor_to_major()) {
+ linear_index += scale * multi_index[dimension];
+ scale *= shape.dimensions(dimension);
+ }
+ return linear_index;
+}
+
+/* static */ std::vector<int64> IndexUtil::LinearIndexToMultidimensionalIndex(
+ const Shape& shape, int64 linear_index) {
+ // Padding and nested layouts not supported yet.
+ CHECK_EQ(0, shape.layout().padded_dimensions_size());
+ CHECK_GE(linear_index, 0);
+ CHECK_LT(linear_index, ShapeUtil::ElementsIn(shape));
+
+ // The following formula computes each element of the multidimensional index
+ // (See comments in MultidimensionalIndexToLinearIndex for notation):
+ //
+ // I{L(0)} = linear_index % D{L(0)}
+ // I{L(1)} = (linear_index / D{L(0)}) % D{L(1)}
+ // I{L(2)} = (linear_index / (D{L(0)} * D{L(1)})) % D{L(2)}
+ // ...
+ std::vector<int64> multi_index(shape.dimensions_size());
+
+ // Accumulated product D{L(0)} * D{L(1)} * ...
+ int64 divisor = 1;
+ for (auto dimension : shape.layout().minor_to_major()) {
+ multi_index[dimension] =
+ (linear_index / divisor) % shape.dimensions(dimension);
+ divisor *= shape.dimensions(dimension);
+ }
+ return multi_index;
+}
+
+/* static */ bool IndexUtil::BumpIndices(const Shape& shape,
+ std::vector<int64>* indices) {
+ for (int64 dimno = indices->size() - 1; dimno >= 0; --dimno) {
+ int64 limit = shape.dimensions(dimno);
+ if ((*indices)[dimno] + 1 < limit) {
+ (*indices)[dimno]++;
+ std::fill(indices->begin() + dimno + 1, indices->end(), 0);
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace xla
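An editorial worked example of the formula above (not part of this diff), using the shapes exercised in index_util_test.cc below: for f32[10, 20] with minor_to_major = {0, 1}, the multi-index {3, 5} maps to 3 * 1 + 5 * 10 = 53, while with minor_to_major = {1, 0} the same index maps to 5 * 1 + 3 * 20 = 65; applying LinearIndexToMultidimensionalIndex to 53 and 65 under the respective layouts recovers {3, 5}.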
diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h
new file mode 100644
index 0000000000..2d8753c3fe
--- /dev/null
+++ b/tensorflow/compiler/xla/index_util.h
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility functions related to layouts of Shapes.
+
+#ifndef TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Namespaced collection of (static) utilities related to indexing into
+// multidimensional arrays.
+class IndexUtil {
+ public:
+ // Converts a multidimensional index (eg {x, y, z}) into a linear index based
+ // on the shape and its layout. The first index in the multi_index is
+ // dimension 0.
+ static int64 MultidimensionalIndexToLinearIndex(
+ const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index);
+
+ // Converts a linear index into a multidimensional index (eg {x, y, z}) based on
+ // the shape and its layout. The first index in the returned multidimensional
+ // index is dimension 0.
+ static std::vector<int64> LinearIndexToMultidimensionalIndex(
+ const Shape& shape, int64 linear_index);
+
+ // Bumps a sequence of indices (e.g. {0,0,0,0}) up by one index value (e.g. to
+ // {0,0,0,1}). This is akin to std::next_permutation. If the index hits a limit
+ // for the provided shape, the next most significant index is bumped, in a
+ // counting-up process.
+ //
+ // E.g. for shape f32[2,3]
+ // {0,0}=>{0,1}
+ // {0,1}=>{0,2}
+ // {0,2}=>{1,0}
+ // etc.
+ //
+ // This is useful for traversing the indices in a literal.
+ //
+ // Returns true iff the indices were successfully bumped; false if we've hit
+ // the limit where it can no longer be bumped in-bounds.
+ static bool BumpIndices(const Shape& shape, std::vector<int64>* indices);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_INDEX_UTIL_H_
diff --git a/tensorflow/compiler/xla/index_util_test.cc b/tensorflow/compiler/xla/index_util_test.cc
new file mode 100644
index 0000000000..85259b33f0
--- /dev/null
+++ b/tensorflow/compiler/xla/index_util_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/index_util.h"
+
+#include <initializer_list>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+void SetMinorToMajorLayout(Shape* shape,
+ std::initializer_list<int64> dimensions) {
+ shape->mutable_layout()->clear_minor_to_major();
+ for (auto dimension : dimensions) {
+ shape->mutable_layout()->add_minor_to_major(dimension);
+ }
+}
+
+TEST(IndexUtilTest, VectorIndexing) {
+ // Vectors are trivially laid out and the linear index should always be the
+ // same as the "multidimensional" index.
+ Shape vector_shape = ShapeUtil::MakeShape(F32, {100});
+ EXPECT_EQ(42,
+ IndexUtil::MultidimensionalIndexToLinearIndex(vector_shape, {42}));
+ std::vector<int64> multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(vector_shape, 42);
+ EXPECT_EQ(1, multi_index.size());
+ EXPECT_EQ(42, multi_index[0]);
+}
+
+TEST(IndexUtilTest, MatrixIndexingRowMajor) {
+ // Set layout to [0, 1]. That is, row major.
+ Shape matrix_shape_01 = ShapeUtil::MakeShape(F32, {10, 20});
+ SetMinorToMajorLayout(&matrix_shape_01, {0, 1});
+
+ // If index is {a, b} then linear index should be: a + b * 10
+ EXPECT_EQ(0, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01,
+ {0, 0}));
+ EXPECT_EQ(199, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01,
+ {9, 19}));
+ EXPECT_EQ(53, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_01,
+ {3, 5}));
+ EXPECT_EQ(std::vector<int64>({3, 5}),
+ IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_01, 53));
+}
+
+TEST(IndexUtilTest, MatrixIndexingColumnMajor) {
+ // Set layout to [1, 0]. That is, column major.
+ Shape matrix_shape_10 = ShapeUtil::MakeShape(F32, {10, 20});
+ SetMinorToMajorLayout(&matrix_shape_10, {1, 0});
+
+ // If index is {a, b} then linear index should be: a * 20 + b
+ EXPECT_EQ(0, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10,
+ {0, 0}));
+ EXPECT_EQ(199, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10,
+ {9, 19}));
+ EXPECT_EQ(65, IndexUtil::MultidimensionalIndexToLinearIndex(matrix_shape_10,
+ {3, 5}));
+ EXPECT_EQ(std::vector<int64>({3, 5}),
+ IndexUtil::LinearIndexToMultidimensionalIndex(matrix_shape_10, 65));
+}
+
+TEST(IndexUtilTest, ThreeDArrayIndexing210) {
+ // Set layout to [2, 1, 0]. That is, column major.
+ Shape shape_210 = ShapeUtil::MakeShape(F32, {10, 20, 30});
+ SetMinorToMajorLayout(&shape_210, {2, 1, 0});
+
+ // If index is {a, b, c} then linear index should be:
+ // a * 20 * 30 + b * 30 + c
+ EXPECT_EQ(1957, IndexUtil::MultidimensionalIndexToLinearIndex(shape_210,
+ {3, 5, 7}));
+ EXPECT_EQ(5277, IndexUtil::MultidimensionalIndexToLinearIndex(shape_210,
+ {8, 15, 27}));
+}
+
+TEST(IndexUtilTest, ThreeDArrayIndexing120) {
+ // Set layout to [1, 2, 0]
+ Shape shape_120 = ShapeUtil::MakeShape(F32, {10, 20, 30});
+ SetMinorToMajorLayout(&shape_120, {1, 2, 0});
+
+ // If index is {a, b, c} then linear index should be:
+ // a * 20 * 30 + b + c * 20
+ EXPECT_EQ(1945, IndexUtil::MultidimensionalIndexToLinearIndex(shape_120,
+ {3, 5, 7}));
+ EXPECT_EQ(5355, IndexUtil::MultidimensionalIndexToLinearIndex(shape_120,
+ {8, 15, 27}));
+}
+
+TEST(IndexUtilTest, FourDArrayIndexing3210) {
+ // Set layout to [3, 2, 1, 0]. That is, column major.
+ Shape shape_3210 = ShapeUtil::MakeShape(F32, {10, 20, 30, 40});
+ SetMinorToMajorLayout(&shape_3210, {3, 2, 1, 0});
+
+ // If index is {a, b, c, d} then linear index should be:
+ // a * 20 * 30 * 40 + b * 30 * 40 + c * 40 + d
+ EXPECT_EQ(78289, IndexUtil::MultidimensionalIndexToLinearIndex(shape_3210,
+ {3, 5, 7, 9}));
+ EXPECT_EQ(211113, IndexUtil::MultidimensionalIndexToLinearIndex(
+ shape_3210, {8, 15, 27, 33}));
+}
+
+TEST(IndexUtilTest, LinearToMultiToLinear) {
+ // Verify that converting a linear index to a multidimensional index and back
+ // always returns the same value for different crazy shapes. Shape has
+ // 1440000000 elements. Inputs are randomly-ish selected.
+ std::vector<int64> linear_indexes = {0, 1439999999, 1145567336,
+ 43883404, 617295214, 1117613654};
+
+ std::vector<std::initializer_list<int64>> minor_to_major_orders;
+ minor_to_major_orders.push_back({6, 5, 4, 3, 2, 1, 0});
+ minor_to_major_orders.push_back({0, 1, 2, 3, 4, 5, 6});
+ minor_to_major_orders.push_back({4, 5, 1, 2, 6, 0, 3});
+
+ for (auto minor_to_major_order : minor_to_major_orders) {
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 20, 30, 40, 30, 20, 10});
+ SetMinorToMajorLayout(&shape, minor_to_major_order);
+ for (auto linear_index : linear_indexes) {
+ std::vector<int64> multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index);
+ EXPECT_EQ(linear_index, IndexUtil::MultidimensionalIndexToLinearIndex(
+ shape, multi_index));
+ }
+ }
+}
+
+TEST(IndexUtilTest, BumpIndices2x2) {
+ auto shape = ShapeUtil::MakeShape(S32, {2, 2});
+ std::vector<int64> indices = {0, 0};
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_MATCH(indices,
+ testing::VectorMatcher<int64>(std::vector<int64>{0, 1}));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_MATCH(indices,
+ testing::VectorMatcher<int64>(std::vector<int64>{1, 0}));
+ EXPECT_TRUE(IndexUtil::BumpIndices(shape, &indices));
+ EXPECT_MATCH(indices,
+ testing::VectorMatcher<int64>(std::vector<int64>{1, 1}));
+ EXPECT_FALSE(IndexUtil::BumpIndices(shape, &indices));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
new file mode 100644
index 0000000000..81eb717821
--- /dev/null
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -0,0 +1,363 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/layout_util.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+namespace {
+
+using DimensionOrder = legacy_flags::DefaultLayout::DimensionOrder;
+
+// Internal helper for GetDefaultLayoutForShape and SetToDefaultLayout. Sets
+// minor_to_major to the value that represents the default layout.
+void SetDefaultLayoutToContainer(
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+ minor_to_major) {
+ const int64 size = minor_to_major->size();
+ legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags();
+ auto default_layout = flags->xla_default_layout;
+ switch (default_layout.dimension_order) {
+ case DimensionOrder::kMajorToMinor:
+ for (int64 i = 0; i < size; ++i) {
+ minor_to_major->Set(i, size - 1 - i);
+ }
+ break;
+ case DimensionOrder::kMinorToMajor:
+ for (int64 i = 0; i < size; ++i) {
+ minor_to_major->Set(i, i);
+ }
+ break;
+ case DimensionOrder::kRandom:
+ for (int64 i = 0; i < size; ++i) {
+ minor_to_major->Set(i, i);
+ }
+ std::shuffle(
+ minor_to_major->begin(), minor_to_major->end(),
+ std::mt19937(default_layout.seed != 0 ? default_layout.seed
+ : std::random_device()()));
+ }
+}
+
+} // namespace
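As an editorial note (not part of this diff): for a rank-3 shape, the kMajorToMinor branch above fills minor_to_major with {2, 1, 0} (dimension 2 becomes the most minor), kMinorToMajor fills it with {0, 1, 2}, and kRandom starts from {0, 1, 2} and shuffles it using the configured seed, or a random one when the seed flag is 0.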
+
+/* static */ Layout LayoutUtil::MakeLayout(
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ Layout layout;
+ for (int64 dimension_number : minor_to_major) {
+ layout.add_minor_to_major(dimension_number);
+ }
+ return layout;
+}
+
+namespace {
+
+// Internal helper that creates a default layout for an array of the given rank.
+Layout CreateDefaultLayoutForRank(int64 rank) {
+ Layout layout;
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+ minor_to_major = layout.mutable_minor_to_major();
+ minor_to_major->Resize(rank, 0);
+ SetDefaultLayoutToContainer(minor_to_major);
+ return layout;
+}
+
+} // namespace
+
+/* static */ Layout LayoutUtil::GetDefaultLayoutForShape(const Shape& shape) {
+ // A Layout proto corresponds to a single array, not a tuple.
+ DCHECK(!ShapeUtil::IsTuple(shape));
+ return CreateDefaultLayoutForRank(shape.dimensions_size());
+}
+
+/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
+ return CreateDefaultLayoutForRank(2);
+}
+
+/* static */ Layout LayoutUtil::GetDefaultLayoutForR3() {
+ return CreateDefaultLayoutForRank(3);
+}
+
+/* static */ Layout LayoutUtil::GetDefaultLayoutForR4() {
+ return CreateDefaultLayoutForRank(4);
+}
+
+/* static */ void LayoutUtil::SetToDefaultLayout(Shape* shape) {
+ if (ShapeUtil::IsTuple(*shape)) {
+ // Tuple shape.
+ for (auto& element_shape : *shape->mutable_tuple_shapes()) {
+ SetToDefaultLayout(&element_shape);
+ }
+ } else {
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+ minor_to_major = shape->mutable_layout()->mutable_minor_to_major();
+ minor_to_major->Resize(shape->dimensions_size(), 0);
+ SetDefaultLayoutToContainer(minor_to_major);
+ }
+}
+
+/* static */ void LayoutUtil::SetToDefaultLayout(ProgramShape* program_shape) {
+ for (auto& parameter_shape : *program_shape->mutable_parameters()) {
+ LayoutUtil::SetToDefaultLayout(&parameter_shape);
+ }
+ LayoutUtil::SetToDefaultLayout(program_shape->mutable_result());
+}
+
+/* static */ tensorflow::Status LayoutUtil::ValidateLayoutInShape(
+ const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ // Tuple shape.
+ if (shape.has_layout()) {
+ return InvalidArgument("tuple should not have a layout field");
+ }
+ for (auto& element_shape : shape.tuple_shapes()) {
+ TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape));
+ }
+ return tensorflow::Status::OK();
+ } else if (ShapeUtil::Rank(shape) == 0 && !shape.has_layout()) {
+ // A scalar without a layout is ok.
+ return tensorflow::Status::OK();
+ } else {
+ // Array shape.
+ if (!shape.has_layout()) {
+ return InvalidArgument("shape does not have a layout");
+ }
+ return ValidateLayoutForShape(shape.layout(), shape);
+ }
+}
+
+/* static */ tensorflow::Status LayoutUtil::ValidateLayoutForShape(
+ const Layout& layout, const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ return InvalidArgument("a single Layout is not valid for tuple shapes");
+ }
+
+ if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) {
+ return InvalidArgument(
+ "layout minor_to_major field contains %d elements, "
+ "but shape is rank %lld: {%s}; shape: %s",
+ layout.minor_to_major_size(), ShapeUtil::Rank(shape),
+ tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(),
+ shape.ShortDebugString().c_str());
+ }
+
+ std::vector<bool> dimensions_in_layout(ShapeUtil::Rank(shape), false);
+ for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) {
+ int64 dim = layout.minor_to_major(i);
+ if (dim < 0 || dim >= ShapeUtil::Rank(shape)) {
+ return InvalidArgument(
+ "layout minor_to_major field has out-of-bounds value: %s",
+ HumanString(layout).c_str());
+ }
+ if (dimensions_in_layout[dim]) {
+ return InvalidArgument(
+ "layout minor_to_major field has duplicate values: {%s}",
+ HumanString(layout).c_str());
+ }
+ dimensions_in_layout[dim] = true;
+ }
+
+ if (layout.padded_dimensions_size() > 0) {
+ if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) {
+ return InvalidArgument(
+ "layout has %d padded dimensions, but shape is rank %lld",
+ layout.padded_dimensions_size(), ShapeUtil::Rank(shape));
+ }
+ for (int i = 0; i < layout.padded_dimensions_size(); ++i) {
+ if (layout.padded_dimensions(i) < shape.dimensions(i)) {
+ return InvalidArgument(
+ "for dimension %d, dimension padding (%lld) is smaller than "
+ "the dimension size (%lld) of the shape",
+ i, layout.padded_dimensions(i), shape.dimensions(i));
+ }
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+/* static */ void LayoutUtil::ClearLayout(Shape* shape) {
+ shape->clear_layout();
+ for (auto& element_shape : *shape->mutable_tuple_shapes()) {
+ ClearLayout(&element_shape);
+ }
+}
+
+/* static */ void LayoutUtil::ClearLayout(ProgramShape* program_shape) {
+ for (auto& parameter_shape : *program_shape->mutable_parameters()) {
+ LayoutUtil::ClearLayout(&parameter_shape);
+ }
+ LayoutUtil::ClearLayout(program_shape->mutable_result());
+}
+
+/* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) {
+ return std::is_sorted(layout.minor_to_major().begin(),
+ layout.minor_to_major().end());
+}
+
+/* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const Layout& layout) {
+ return std::is_sorted(layout.minor_to_major().begin(),
+ layout.minor_to_major().end(), std::greater<int64>());
+}
+
+/* static */ bool LayoutUtil::IsPadded(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape) || !HasLayout(shape) ||
+ shape.layout().padded_dimensions_size() == 0) {
+ return false;
+ }
+ CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size());
+ for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+ if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+/* static */ bool LayoutUtil::HasLayout(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ // Tuple shape: all subshapes must have a layout.
+ return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(),
+ [](const Shape& s) { return HasLayout(s); });
+ }
+ // A scalar trivially always has a layout.
+ return (ShapeUtil::Rank(shape) == 0 ||
+ (shape.has_layout() && (shape.layout().minor_to_major_size() > 0)));
+}
+
+/* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) {
+ for (auto& parameter_shape : program_shape.parameters()) {
+ if (!LayoutUtil::HasLayout(parameter_shape)) {
+ return false;
+ }
+ }
+ return LayoutUtil::HasLayout(program_shape.result());
+}
+
+/* static */ bool LayoutUtil::Equal(const Layout& lhs, const Layout& rhs) {
+ return protobuf_util::ProtobufEquals(lhs, rhs);
+}
+
+/* static */ int64 LayoutUtil::Major(const Layout& layout,
+ int64 physical_dimension_number) {
+ CHECK_LE(0, physical_dimension_number);
+ CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
+ return Minor(layout,
+ layout.minor_to_major_size() - 1 - physical_dimension_number);
+}
+
+/* static */ int64 LayoutUtil::Minor(const Layout& layout,
+ int64 physical_dimension_number) {
+ CHECK_LE(0, physical_dimension_number);
+ CHECK_LT(physical_dimension_number, layout.minor_to_major_size());
+ return layout.minor_to_major(physical_dimension_number);
+}
+
+/* static */ std::vector<int64> LayoutUtil::MakeLogicalToPhysical(
+ const Layout& layout) {
+ std::vector<int64> logical_to_physical(layout.minor_to_major_size());
+ for (int64 physical = 0; physical < logical_to_physical.size(); ++physical) {
+ const int64 logical = Major(layout, physical);
+ logical_to_physical[logical] = physical;
+ }
+ return logical_to_physical;
+}
+
+/* static */ string LayoutUtil::HumanString(const Layout& layout) {
+ return tensorflow::strings::StrCat(
+ "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}");
+}
+
+namespace {
+
+// Internal helper for recursively copying layouts.
+tensorflow::Status CopyLayoutInternal(const Shape& src, Shape* dst) {
+ if (ShapeUtil::IsTuple(src)) {
+ DCHECK(ShapeUtil::IsTuple(*dst));
+ DCHECK_EQ(ShapeUtil::TupleElementCount(src),
+ ShapeUtil::TupleElementCount(*dst));
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(src); ++i) {
+ TF_RETURN_IF_ERROR(CopyLayoutInternal(src.tuple_shapes(i),
+ dst->mutable_tuple_shapes(i)));
+ }
+ } else {
+ if (src.has_layout()) {
+ TF_RETURN_IF_ERROR(
+ LayoutUtil::ValidateLayoutForShape(src.layout(), *dst));
+ *dst->mutable_layout() = src.layout();
+ } else {
+ dst->clear_layout();
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+/* static */
+tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src,
+ Shape* dst) {
+ if (!ShapeUtil::Compatible(src, *dst)) {
+ return InvalidArgument(
+ "cannot copy layout from shape %s to shape %s: "
+ "shapes are not compatible",
+ ShapeUtil::HumanString(src).c_str(),
+ ShapeUtil::HumanString(*dst).c_str());
+ }
+ return CopyLayoutInternal(src, dst);
+}
+
+/* static */ bool LayoutUtil::LayoutsInShapesEqual(const Shape& lhs,
+ const Shape& rhs) {
+ if (ShapeUtil::IsTuple(lhs) != ShapeUtil::IsTuple(rhs)) {
+ return false;
+ }
+ if (ShapeUtil::IsTuple(lhs)) {
+ if (ShapeUtil::TupleElementCount(lhs) !=
+ ShapeUtil::TupleElementCount(rhs)) {
+ return false;
+ }
+ for (int i = 0; i < ShapeUtil::TupleElementCount(lhs); ++i) {
+ if (!LayoutsInShapesEqual(lhs.tuple_shapes(i), rhs.tuple_shapes(i))) {
+ return false;
+ }
+ }
+ return true;
+ } else {
+ return ShapeUtil::SameDimensions(lhs, rhs) &&
+ LayoutUtil::Equal(lhs.layout(), rhs.layout());
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
new file mode 100644
index 0000000000..984bf402cd
--- /dev/null
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -0,0 +1,153 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility functions related to layouts of Shapes.
+
+#ifndef TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Namespaced collection of (static) Layout utilities.
+class LayoutUtil {
+ public:
+ // Creates a layout with the given minor-to-major dimension order. (This is a
+ // convenience function for protobuf construction.)
+ static Layout MakeLayout(tensorflow::gtl::ArraySlice<int64> minor_to_major);
+
+ // Returns default layout for the given shape.
+ static Layout GetDefaultLayoutForShape(const Shape& shape);
+
+ // Helper functions that create default layouts for various ranks.
+ static Layout GetDefaultLayoutForR2();
+ static Layout GetDefaultLayoutForR3();
+ static Layout GetDefaultLayoutForR4();
+
+ // Sets the default layout on the Shape.
+ static void SetToDefaultLayout(Shape* shape);
+
+ // Sets the layouts of all Shapes within the given ProgramShape to the
+ // default.
+ static void SetToDefaultLayout(ProgramShape* program_shape);
+
+ // Validates that the layout within the given shape is correct.
+ static tensorflow::Status ValidateLayoutInShape(const Shape& shape);
+
+ // Validates that the provided layout satisfies invariants for the given
+ // shape.
+ static tensorflow::Status ValidateLayoutForShape(const Layout& layout,
+ const Shape& shape);
+
+ // Clears the layout in the given Shape. After this function is called,
+ // HasLayout will return false for the shape.
+ static void ClearLayout(Shape* shape);
+
+ // Clears the layout on all Shapes within the given ProgramShape.
+ static void ClearLayout(ProgramShape* program_shape);
+
+ // Returns whether the layout is monotonic and dim 0 is minor in the layout.
+ // * R0 and R1: this is always trivially true.
+ // * R2+: equivalent to column-major. Dimension 0 is the minor, dimension 1 is
+ // more major, and so on until dimension N-1 which is the major.
+ static bool IsMonotonicWithDim0Minor(const Layout& layout);
+
+ // Returns whether the layout is monotonic and dim 0 is major in the layout.
+ // * R0 and R1: this is always trivially true.
+ // * R2+: equivalent to row-major. Dimension 0 is the major, dimension 1 is
+ // more minor, and so on until dimension N-1 which is the minor.
+ static bool IsMonotonicWithDim0Major(const Layout& layout);
+
+ // Returns whether the layout of the given shape has padding (a
+ // padded_dimension value in Layout is greater than the corresponding
+ // dimension size).
+ static bool IsPadded(const Shape& shape);
+
+ // Returns whether the given shape has a layout. For tuple shapes, true is
+ // returned only if all elements have layouts.
+ static bool HasLayout(const Shape& shape);
+
+ // Returns whether all Shapes within the given ProgramShape have layouts.
+ static bool HasLayout(const ProgramShape& program_shape);
+
+ // Returns whether lhs and rhs are identical.
+ static bool Equal(const Layout& lhs, const Layout& rhs);
+
+ // Major(0) is the most major logical dimension number, major(1) is the
+ // second-most-major logical dimension number and so on.
+ //
+ // This can be used to translate physical dimension numbers to logical
+ // dimension numbers. Assume that we are numbering the physical dimensions so
+ // that the most major physical dimension has physical dimension number 0 and
+ // so on. Then a physical dimension number p corresponds to the logical
+ // dimension number Major(p). So this function could also be called
+ // PhysicalToLogical().
+ //
+ // As an example, consider physical dimension number 0, which by definition is
+ // the most major. Then Major(0) is the most major logical dimension, so Major
+ // maps the physical dimension number 0 to the most major logical dimension
+ // number Major(0).
+ static int64 Major(const Layout& layout, int64 physical_dimension_number);
+
+ // Minor(0) is the most minor logical dimension number, minor(1) is the
+ // second-most-minor logical dimension number and so on.
+ static int64 Minor(const Layout& layout, int64 physical_dimension_number);
+
+ // Returns the inverse mapping of the Major() function. More precisely, return
+ // a vector v such that if l == Major(p), then v[l] == p.
+ //
+ // This can be used to translate logical dimension numbers into physical
+ // dimension numbers. Assume that we are numbering the physical dimensions so
+ // that the most major physical dimension has physical dimension number 0 and
+ // so on. Then a logical dimension number l corresponds to the physical
+ // dimension number MakeLogicalToPhysical(layout)[l].
+ //
+ // As an example, consider physical dimension number 0, which by definition is
+ // the most major. Then l := Major(0) is the most major logical dimension. If
+ // v is the vector returned from this function, then v[l] == 0. So v maps the
+ // most major logical dimension l to the physical dimension number 0.
+ static std::vector<int64> MakeLogicalToPhysical(const Layout& layout);
+
+ // Returns a human-readable string that represents the given layout.
+ static string HumanString(const Layout& layout);
+
+ // Copies the layout from 'src' to 'dst'. Recursively copies layouts of
+ // tuples. 'src' and 'dst' must be compatible.
+ static tensorflow::Status CopyLayoutBetweenShapes(const Shape& src,
+ Shape* dst);
+
+ // Returns true if the layouts of lhs and rhs are equal, false
+ // otherwise. Recursively compares layouts of tuples.
+ //
+ // Since the structure of the shape determines the structure of the layout,
+ // this returns true if and only if the shapes are identical (ignoring
+ // element type) and their layouts are identical.
+ static bool LayoutsInShapesEqual(const Shape& lhs, const Shape& rhs);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_
diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc
new file mode 100644
index 0000000000..5ba1122946
--- /dev/null
+++ b/tensorflow/compiler/xla/layout_util_test.cc
@@ -0,0 +1,246 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class LayoutUtilTest : public ::testing::Test {
+ protected:
+ Shape MakeShapeWithLayout(PrimitiveType element_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
+ *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+ return shape;
+ }
+};
+
+TEST_F(LayoutUtilTest, TupleLayoutComparison) {
+ Shape shape =
+ ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1})});
+ Shape other_shape =
+ ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+
+ Shape tuple0 = ShapeUtil::MakeTupleShape({});
+ Shape tuple1 = ShapeUtil::MakeTupleShape({shape});
+ Shape tuple2 = ShapeUtil::MakeTupleShape({shape, shape});
+
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple0));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple1));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple0, tuple2));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple0));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple0));
+
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple1));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple1, tuple2));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple1));
+
+ Shape other_tuple2 = ShapeUtil::MakeTupleShape({shape, other_shape});
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(tuple2, tuple2));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(tuple2, other_tuple2));
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(other_tuple2, tuple2));
+}
+
+TEST_F(LayoutUtilTest, CopyLayoutArray) {
+ Shape src = MakeShapeWithLayout(F32, {2, 3}, {0, 1});
+ Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
+
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+
+ // Should work if destination has no layout.
+ dst.clear_layout();
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+
+ // If source is cleared, then destination should be cleared.
+ src.clear_layout();
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_TRUE(dst.has_layout());
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_FALSE(dst.has_layout());
+}
+
+TEST_F(LayoutUtilTest, CopyLayoutTuple) {
+ Shape src = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+ ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {0, 2, 1})})});
+ Shape dst = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+ ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+
+ EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+ EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst));
+}
+
+TEST_F(LayoutUtilTest, CopyLayoutNotCompatible) {
+ Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1});
+ Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0});
+ auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
+ EXPECT_FALSE(status.ok());
+ EXPECT_MATCH(status.error_message(),
+ testing::ContainsRegex("cannot copy layout from shape"));
+}
+
+TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) {
+ Shape src =
+ ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+ ShapeUtil::MakeTupleShape({MakeShapeWithLayout(
+ F32, {1, 2, 3}, {0, 2, 1})})});
+ Shape dst = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+ ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+
+ auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
+ EXPECT_FALSE(status.ok());
+ EXPECT_MATCH(status.error_message(),
+ testing::ContainsRegex("cannot copy layout from shape"));
+}
+
+TEST_F(LayoutUtilTest, CopyLayoutBogusLayout) {
+ Shape src = ShapeUtil::MakeShape(F32, {2, 3});
+ Shape dst = ShapeUtil::MakeShape(F32, {2, 3});
+ // Set layout to invalid value.
+ *src.mutable_layout() = LayoutUtil::MakeLayout({1, 2, 3, 4});
+
+ auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst);
+ EXPECT_FALSE(status.ok());
+ EXPECT_MATCH(status.error_message(),
+ testing::ContainsRegex("layout minor_to_major field contains .* "
+ "elements, but shape is rank"));
+}
+
+TEST_F(LayoutUtilTest, ClearLayoutTuple) {
+ Shape shape = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3}, {1, 0}),
+ MakeShapeWithLayout(F32, {42, 123}, {1, 0}),
+ ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})});
+ EXPECT_TRUE(LayoutUtil::HasLayout(shape));
+ EXPECT_TRUE(shape.tuple_shapes(0).has_layout());
+ EXPECT_TRUE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
+
+ LayoutUtil::ClearLayout(&shape);
+
+ EXPECT_FALSE(LayoutUtil::HasLayout(shape));
+ EXPECT_FALSE(shape.tuple_shapes(0).has_layout());
+ EXPECT_FALSE(shape.tuple_shapes(2).tuple_shapes(1).has_layout());
+}
+
+TEST_F(LayoutUtilTest, SetToDefaultLayoutTuple) {
+ Shape shape = ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {2, 3, 4}, {1, 0, 2}),
+ MakeShapeWithLayout(F32, {42, 123, 7}, {1, 2, 0}),
+ ShapeUtil::MakeTupleShape(
+ {MakeShapeWithLayout(F32, {}, {}),
+ MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 1, 2, 0})})});
+ EXPECT_FALSE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(),
+ shape.tuple_shapes(1).layout()));
+ LayoutUtil::SetToDefaultLayout(&shape);
+ EXPECT_TRUE(LayoutUtil::Equal(shape.tuple_shapes(0).layout(),
+ shape.tuple_shapes(1).layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(
+ LayoutUtil::GetDefaultLayoutForShape(shape.tuple_shapes(0)),
+ shape.tuple_shapes(1).layout()));
+}
+
+TEST_F(LayoutUtilTest, IsPadded) {
+ Shape shape_without_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
+ LayoutUtil::ClearLayout(&shape_without_layout);
+ EXPECT_FALSE(LayoutUtil::IsPadded(shape_without_layout));
+
+ Shape shape_with_layout = ShapeUtil::MakeShape(F32, {2, 3, 4});
+ LayoutUtil::SetToDefaultLayout(&shape_with_layout);
+ EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_layout));
+
+ // Add padding equal to the dimension sizes. In this case the padding is a
+ // nop.
+ Shape shape_with_degenerate_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
+ shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(2);
+ shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(3);
+ shape_with_degenerate_padding.mutable_layout()->add_padded_dimensions(4);
+ EXPECT_FALSE(LayoutUtil::IsPadded(shape_with_degenerate_padding));
+
+ Shape shape_with_padding = ShapeUtil::MakeShape(F32, {2, 3, 4});
+ shape_with_padding.mutable_layout()->add_padded_dimensions(2);
+ shape_with_padding.mutable_layout()->add_padded_dimensions(14);
+ shape_with_padding.mutable_layout()->add_padded_dimensions(42);
+ EXPECT_TRUE(LayoutUtil::IsPadded(shape_with_padding));
+}
+
+TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) {
+ // Test that LayoutUtil returns expected layouts when the xla_default_layout
+ // flag is set to kMajorToMinor.
+ legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags();
+ flags->xla_default_layout = xla::legacy_flags::DefaultLayout{
+ .dimension_order =
+ legacy_flags::DefaultLayout::DimensionOrder::kMajorToMinor};
+
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({1, 0}),
+ LayoutUtil::GetDefaultLayoutForR2()));
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({2, 1, 0}),
+ LayoutUtil::GetDefaultLayoutForR3()));
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({3, 2, 1, 0}),
+ LayoutUtil::GetDefaultLayoutForR4()));
+ EXPECT_TRUE(
+ LayoutUtil::Equal(LayoutUtil::MakeLayout({4, 3, 2, 1, 0}),
+ LayoutUtil::GetDefaultLayoutForShape(
+ ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
+}
+
+TEST_F(LayoutUtilTest, DefaultLayoutGettersMinorToMajor) {
+ // Test that LayoutUtil returns expected layouts when the xla_default_layout
+ // flag is set to kMinorToMajor.
+ legacy_flags::LayoutUtilFlags* flags = legacy_flags::GetLayoutUtilFlags();
+ flags->xla_default_layout = xla::legacy_flags::DefaultLayout{
+ .dimension_order =
+ legacy_flags::DefaultLayout::DimensionOrder::kMinorToMajor};
+
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}),
+ LayoutUtil::GetDefaultLayoutForR2()));
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}),
+ LayoutUtil::GetDefaultLayoutForR3()));
+ EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3}),
+ LayoutUtil::GetDefaultLayoutForR4()));
+ EXPECT_TRUE(
+ LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2, 3, 4}),
+ LayoutUtil::GetDefaultLayoutForShape(
+ ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25}))));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD
new file mode 100644
index 0000000000..c98232cdf6
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/BUILD
@@ -0,0 +1,267 @@
+# Legacy command line flags for the XLA libraries.
+
+# Please do not add more flags to this package.
+
+# The XLA libraries were written in an environment that allowed command-line
+# flags to be scattered freely throughout the libraries. This model, while
+# initially convenient, leads to a proliferation of unused command-line flags
+# in tests and binaries, and to serious problems in servers, where one might
+# wish parameters to be different in independent RPC calls to the same routine.
+#
+# Please don't add more flags. If you're a library author, pass options and
+# parameters explicitly through the library's interface.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "parse_flags_from_env",
+ srcs = ["parse_flags_from_env.cc"],
+ hdrs = ["parse_flags_from_env.h"],
+ deps =
+ [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "parse_flags_from_env_test",
+ srcs = ["parse_flags_from_env_test.cc"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "layout_util_flags",
+ srcs = ["layout_util_flags.cc"],
+ hdrs = ["layout_util_flags.h"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "util_flags",
+ srcs = ["util_flags.cc"],
+ hdrs = ["util_flags.h"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "cpu_compiler_flags",
+ srcs = ["cpu_compiler_flags.cc"],
+ hdrs = ["cpu_compiler_flags.h"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "cpu_runtime_flags",
+ srcs = ["cpu_runtime_flags.cc"],
+ hdrs = ["cpu_runtime_flags.h"],
+ deps =
+ [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "llvm_backend_flags",
+ srcs = ["llvm_backend_flags.cc"],
+ hdrs = ["llvm_backend_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "compiler_functor_flags",
+ srcs = ["compiler_functor_flags.cc"],
+ hdrs = ["compiler_functor_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "convolution_thunk_flags",
+ srcs = ["convolution_thunk_flags.cc"],
+ hdrs = ["convolution_thunk_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "gpu_compiler_flags",
+ srcs = ["gpu_compiler_flags.cc"],
+ hdrs = ["gpu_compiler_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "gpu_backend_lib_flags",
+ srcs = ["gpu_backend_lib_flags.cc"],
+ hdrs = ["gpu_backend_lib_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "stream_assignment_flags",
+ srcs = ["stream_assignment_flags.cc"],
+ hdrs = ["stream_assignment_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_graph_dumper_flags",
+ srcs = ["hlo_graph_dumper_flags.cc"],
+ hdrs = ["hlo_graph_dumper_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_pass_pipeline_flags",
+ srcs = ["hlo_pass_pipeline_flags.cc"],
+ hdrs = ["hlo_pass_pipeline_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "alias_analysis_flags",
+ srcs = ["alias_analysis_flags.cc"],
+ hdrs = ["alias_analysis_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "llvm_util_flags",
+ srcs = ["llvm_util_flags.cc"],
+ hdrs = ["llvm_util_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "service_flags",
+ srcs = ["service_flags.cc"],
+ hdrs = ["service_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "buffer_assignment_flags",
+ srcs = ["buffer_assignment_flags.cc"],
+ hdrs = ["buffer_assignment_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_test_base_flags",
+ srcs = ["hlo_test_base_flags.cc"],
+ hdrs = ["hlo_test_base_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "backend_flags",
+ srcs = ["backend_flags.cc"],
+ hdrs = ["backend_flags.h"],
+ deps = [
+ ":parse_flags_from_env",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc
new file mode 100644
index 0000000000..474753c10a
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's alias_analysis module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static AliasAnalysisFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new AliasAnalysisFlags;
+ flags->xla_emit_alias_scope = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope,
+ "Use buffer analysis to refine alias-analysis."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's alias_analysis
+// module.
+void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the AliasAnalysisFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+AliasAnalysisFlags* GetAliasAnalysisFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h
new file mode 100644
index 0000000000..369f8cd7ca
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
+
+// Legacy flags for XLA's alias_analysis module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's alias_analysis
+// module.
+void AppendAliasAnalysisFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's alias_analysis module.
+typedef struct {
+ bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis.
+} AliasAnalysisFlags;
+
+// Return a pointer to the AliasAnalysisFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+AliasAnalysisFlags* GetAliasAnalysisFlags();
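+
+// Usage sketch (editor's note; a minimal illustration only). It assumes the
+// tensorflow::Flags::Parse entry point from command_line_flags.h referenced
+// above, and a hypothetical main() that owns argc/argv:
+//
+//   std::vector<tensorflow::Flag> flag_list;
+//   xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list);
+//   tensorflow::Flags::Parse(&argc, argv, flag_list);
+//   bool emit_alias_scope =
+//       xla::legacy_flags::GetAliasAnalysisFlags()->xla_emit_alias_scope;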
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.cc b/tensorflow/compiler/xla/legacy_flags/backend_flags.cc
new file mode 100644
index 0000000000..7c007f4435
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/backend_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's backend module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static BackendFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new BackendFlags;
+ // TODO(b/32648682): Decide if this should continue to be a flag longer term.
+ flags->xla_replicas = 1;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_replicas", &flags->xla_replicas,
+ "The number of replicas to use. 1 means no replication."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's backend module.
+void AppendBackendFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the BackendFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BackendFlags* GetBackendFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/backend_flags.h b/tensorflow/compiler/xla/legacy_flags/backend_flags.h
new file mode 100644
index 0000000000..061238b7e6
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/backend_flags.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_
+
+// Legacy flags for XLA's backend module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's backend module.
+void AppendBackendFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's backend module.
+typedef struct {
+ int64 xla_replicas; // The number of replicas to use. 1 means no
+ // replication.
+} BackendFlags;
+
+// Return a pointer to the BackendFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BackendFlags* GetBackendFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BACKEND_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc
new file mode 100644
index 0000000000..71873f73af
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's buffer_assignment module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static BufferAssignmentFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new BufferAssignmentFlags;
+ flags->xla_enable_buffer_reuse = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_enable_buffer_reuse",
+ &flags->xla_enable_buffer_reuse,
+ "Enable reuse of buffers."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's buffer_assignment
+// module.
+void AppendBufferAssignmentFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the BufferAssignmentFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BufferAssignmentFlags* GetBufferAssignmentFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h
new file mode 100644
index 0000000000..5f098c2663
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_
+
+// Legacy flags for XLA's buffer_assignment module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's buffer_assignment
+// module.
+void AppendBufferAssignmentFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's buffer_assignment module.
+typedef struct {
+ bool xla_enable_buffer_reuse; // Enable reuse of buffers.
+} BufferAssignmentFlags;
+
+// Return a pointer to the BufferAssignmentFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+BufferAssignmentFlags* GetBufferAssignmentFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_BUFFER_ASSIGNMENT_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc
new file mode 100644
index 0000000000..617a9b712e
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's compiler_functor module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static CompilerFunctorFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new CompilerFunctorFlags;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_debug_cpu_dump_ir", &flags->xla_debug_cpu_dump_ir,
+ "Dump IR, before optimizations to a path"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's compiler_functor
+// module.
+void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the CompilerFunctorFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CompilerFunctorFlags* GetCompilerFunctorFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h
new file mode 100644
index 0000000000..28b505ec5e
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
+
+// Legacy flags for XLA's compiler_functor module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's compiler_functor
+// module.
+void AppendCompilerFunctorFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's compiler_functor module.
+typedef struct {
+ string xla_debug_cpu_dump_ir; // Dump IR before optimizations to the given path.
+} CompilerFunctorFlags;
+
+// Return a pointer to the CompilerFunctorFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CompilerFunctorFlags* GetCompilerFunctorFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_COMPILER_FUNCTOR_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc
new file mode 100644
index 0000000000..fe5d19147f
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's convolution_thunk module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static ConvolutionThunkFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new ConvolutionThunkFlags;
+ flags->xla_gpu_autotune_convolution_algorithm = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_gpu_autotune_convolution_algorithm",
+ &flags->xla_gpu_autotune_convolution_algorithm,
+ "Auto-tune the algorithm used by convolution"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's convolution_thunk
+// module.
+void AppendConvolutionThunkFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the ConvolutionThunkFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ConvolutionThunkFlags* GetConvolutionThunkFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h
new file mode 100644
index 0000000000..53d6806a71
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
+
+// Legacy flags for XLA's convolution_thunk module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's convolution_thunk
+// module.
+void AppendConvolutionThunkFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's convolution_thunk module.
+typedef struct {
+ // Auto-tune the algorithm used by convolution
+ bool xla_gpu_autotune_convolution_algorithm;
+} ConvolutionThunkFlags;
+
+// Return a pointer to the ConvolutionThunkFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ConvolutionThunkFlags* GetConvolutionThunkFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CONVOLUTION_THUNK_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc
new file mode 100644
index 0000000000..f8ae25552d
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's cpu_compiler module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static CpuCompilerFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new CpuCompilerFlags;
+ flags->xla_cpu_llvm_opt_level = 2;
+ flags->xla_cpu_llvm_cl_opts = "";
+ flags->xla_cpu_embed_ir = false;
+ flags->xla_cpu_parallel = false;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_cpu_llvm_opt_level", &flags->xla_cpu_llvm_opt_level,
+ "The LLVM optimization level for the CPU XLA backend. "
+ "Valid range is from 0 to 3 where 0 means no optimizations."),
+ tensorflow::Flag(
+ "xla_cpu_llvm_cl_opts", &flags->xla_cpu_llvm_cl_opts,
+ "Comma-separated list of command line options to pass to LLVM."),
+ tensorflow::Flag(
+ "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir,
+ "Embed the LLVM IR module string in the resultant CpuExecutable."),
+ tensorflow::Flag("xla_cpu_parallel", &flags->xla_cpu_parallel,
+ "Use the multi-threaded CPU backend."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's cpu_compiler
+// module.
+void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the CpuCompilerFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CpuCompilerFlags* GetCpuCompilerFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h
new file mode 100644
index 0000000000..16a7b68711
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
+
+// Legacy flags for XLA's cpu_compiler module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's cpu_compiler
+// module.
+void AppendCpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's cpu_compiler module.
+typedef struct {
+ // The LLVM optimization level for the CPU XLA backend.
+ // Valid range is from 0 to 3 where 0 means no optimizations.
+ int32 xla_cpu_llvm_opt_level;
+ string xla_cpu_llvm_cl_opts; // Comma-separated list of command line options
+ // to pass to LLVM.
+ bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant
+ // CpuExecutable
+ bool xla_cpu_parallel; // Use the multi-threaded CPU backend.
+} CpuCompilerFlags;
+
+// Return a pointer to the CpuCompilerFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CpuCompilerFlags* GetCpuCompilerFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc
new file mode 100644
index 0000000000..d7817c5d54
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.cc
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's cpu_runtime module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static CpuRuntimeFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new CpuRuntimeFlags;
+ flags->xla_cpu_use_eigen = true;
+ flags->xla_cpu_multi_thread_eigen = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_cpu_use_eigen", &flags->xla_cpu_use_eigen,
+ "Use Eigen for matrix multiply on the CPU platform. This "
+ "is a useful hack for performance comparisons against "
+ "XLA's implementation."),
+ tensorflow::Flag(
+ "xla_cpu_multi_thread_eigen", &flags->xla_cpu_multi_thread_eigen,
+ "When generating calls to Eigen for matmul and conv, should "
+ "single or multi-threaded eigen be used? "
+ "Only used when --xla_cpu_use_eigen is true."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's cpu_runtime
+// module.
+void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the CpuRuntimeFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CpuRuntimeFlags* GetCpuRuntimeFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h
new file mode 100644
index 0000000000..e3ff30da36
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
+
+// Legacy flags for XLA's cpu_runtime module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's cpu_runtime
+// module.
+void AppendCpuRuntimeFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's cpu_runtime module.
+typedef struct {
+ // Use Eigen for matrix multiply on the CPU platform. This is a useful hack
+ // for performance comparisons against XLA's implementation.
+ bool xla_cpu_use_eigen;
+ // When generating calls to Eigen for matmul and conv, should single or
+ // multi-threaded eigen be used? Only used when --xla_cpu_use_eigen is true.
+ bool xla_cpu_multi_thread_eigen;
+} CpuRuntimeFlags;
+
+// Return a pointer to the CpuRuntimeFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+CpuRuntimeFlags* GetCpuRuntimeFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_RUNTIME_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc
new file mode 100644
index 0000000000..c355b1ed9b
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's gpu_backend_lib module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static GpuBackendLibFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new GpuBackendLibFlags;
+ flags->dump_temp_products_to = "";
+ flags->ftz = false;
+ flags->fma = true;
+ flags->gpu_architecture = "compute_35";
+ flags->verbose_ptx_asm = false;
+ flags->kernel = "";
+ flags->llvm_dump_passes = false;
+ flags->llvm_cl_opts = "";
+ flags->dump_ir_before_passes = false;
+ flags->opt_level = 3;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to,
+ "dump temporary compilation products to this directory. "
+ "If empty, no dump is produced"),
+ tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"),
+ tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"),
+ tensorflow::Flag("gpu_architecture", &flags->gpu_architecture,
+ "GPU architecture"),
+ tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm,
+ "emit PTX assembly with extra comments"),
+ tensorflow::Flag("kernel", &flags->kernel,
+ "only emit the IR and PTX for this kernel"),
+ tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes,
+ "dump the passes LLVM runs to stderr"),
+ tensorflow::Flag(
+ "llvm_cl_opts", &flags->llvm_cl_opts,
+ "comma-separated list of command line options to pass to "
+ "LLVM. For example, --llvm_cl_opts=--print-before=loop-unroll"),
+ tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes,
+ "dump the IR before each optimization pass in "
+ "sequentially-named files."),
+ tensorflow::Flag("opt_level", &flags->opt_level,
+ "optimization level (default to 3)"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's gpu_backend_lib
+// module.
+void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the GpuBackendLibFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+GpuBackendLibFlags* GetGpuBackendLibFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h
new file mode 100644
index 0000000000..fbb8863454
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
+
+// Legacy flags for XLA's gpu_backend_lib module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib
+// module.
+void AppendGpuBackendLibFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's gpu_backend_lib module.
+typedef struct {
+ string dump_temp_products_to; // temporary compilation products dir
+ bool ftz; // flush to zero semantics
+ bool fma; // use FMA synthesis
+ string gpu_architecture; // GPU architecture
+ bool verbose_ptx_asm; // emit PTX assembly with extra comments
+ string kernel; // only emit the IR and PTX for this kernel
+ bool llvm_dump_passes; // dump the passes LLVM runs to stderr
+ string llvm_cl_opts; // comma-separated list of LLVM options
+ bool dump_ir_before_passes; // dump IR before each pass
+ int32 opt_level; // optimization level
+} GpuBackendLibFlags;
+
+// Return a pointer to the GpuBackendLibFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+GpuBackendLibFlags* GetGpuBackendLibFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_
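Not part of this patch: every legacy_flags module added in this change exposes the same Append*/Get* pair declared above, so one minimal usage sketch covers them all. Everything here except the Append/Get functions and tensorflow::Flags is illustrative:

  #include <cstdio>
  #include <vector>

  #include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h"
  #include "tensorflow/core/util/command_line_flags.h"

  int main(int argc, char** argv) {
    std::vector<tensorflow::Flag> flag_list;
    // Registers --ftz, --gpu_architecture, ... and also applies any values
    // already present in TF_XLA_FLAGS via ParseFlagsFromEnv.
    xla::legacy_flags::AppendGpuBackendLibFlags(&flag_list);
    // Values parsed from the command line here override the environment values.
    tensorflow::Flags::Parse(&argc, argv, flag_list);
    const xla::legacy_flags::GpuBackendLibFlags* flags =
        xla::legacy_flags::GetGpuBackendLibFlags();
    printf("gpu_architecture=%s ftz=%d\n", flags->gpu_architecture.c_str(),
           flags->ftz ? 1 : 0);
    return 0;
  }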
diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
new file mode 100644
index 0000000000..e79d363509
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's gpu_compiler module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static GpuCompilerFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new GpuCompilerFlags;
+ flags->xla_gpu_embed_ir = false;
+ flags->xla_cuda_data_dir = "./cuda_sdk_lib";
+ flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas";
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
+ "Embed the LLVM IR module string in the resultant GpuExecutable."),
+ tensorflow::Flag(
+ "xla_cuda_data_dir", &flags->xla_cuda_data_dir,
+ "If non-empty, specifies a local directory containing ptxas and "
+ "nvvm libdevice files. Otherwise, by default, we use those from "
+ "runfile directories."),
+ tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path,
+ "The path to ptxas. Required to log stats of the ptx."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's gpu_compiler
+// module.
+void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the GpuCompilerFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+GpuCompilerFlags* GetGpuCompilerFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h
new file mode 100644
index 0000000000..04ddedab73
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
+
+// Legacy flags for XLA's gpu_compiler module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's gpu_compiler
+// module.
+void AppendGpuCompilerFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's gpu_compiler module.
+typedef struct {
+ bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant
+ // GpuExecutable.
+ string xla_cuda_data_dir; // If non-empty, specifies a local directory
+ // containing ptxas and nvvm libdevice files.
+ // Otherwise, by default, we use those from runfile
+ // directories.
+ string xla_ptxas_path; // The path to ptxas. Required to log stats of
+ // the ptx.
+} GpuCompilerFlags;
+
+// Return a pointer to the GpuCompilerFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+GpuCompilerFlags* GetGpuCompilerFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc
new file mode 100644
index 0000000000..8822f6f610
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's hlo_graph_dumper module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static HloGraphDumperFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new HloGraphDumperFlags;
+ flags->xla_hlo_dump_graph_path = "/tmp/";
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_hlo_dump_graph_path",
+ &flags->xla_hlo_dump_graph_path,
+ "Path to write dumped HLO graphs to"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's hlo_graph_dumper
+// module.
+void AppendHloGraphDumperFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the HloGraphDumperFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloGraphDumperFlags* GetHloGraphDumperFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h
new file mode 100644
index 0000000000..b6dfced87c
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_
+
+// Legacy flags for XLA's hlo_graph_dumper module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's hlo_graph_dumper
+// module.
+void AppendHloGraphDumperFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's hlo_graph_dumper module.
+typedef struct {
+ string xla_hlo_dump_graph_path; // Path to write dumped HLO graphs to
+} HloGraphDumperFlags;
+
+// Return a pointer to the HloGraphDumperFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloGraphDumperFlags* GetHloGraphDumperFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_GRAPH_DUMPER_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc
new file mode 100644
index 0000000000..edc04d51a7
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's hlo_pass_pipeline module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static HloPassPipelineFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new HloPassPipelineFlags;
+ flags->xla_disable_hlo_passes = "";
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_disable_hlo_passes", &flags->xla_disable_hlo_passes,
+ "Comma-separated list of HLO passes to disable."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline
+// module.
+void AppendHloPassPipelineFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the HloPassPipelineFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloPassPipelineFlags* GetHloPassPipelineFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h
new file mode 100644
index 0000000000..520759bbf0
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
+
+// Legacy flags for XLA's hlo_pass_pipeline module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's hlo_pass_pipeline
+// module.
+void AppendHloPassPipelineFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's hlo_pass_pipeline module.
+typedef struct {
+ // Comma-separated list of HLO passes to disable.
+ string xla_disable_hlo_passes;
+} HloPassPipelineFlags;
+
+// Return a pointer to the HloPassPipelineFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloPassPipelineFlags* GetHloPassPipelineFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_PASS_PIPELINE_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc
new file mode 100644
index 0000000000..c7893c1385
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's hlo_test_base module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static HloTestBaseFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new HloTestBaseFlags;
+ flags->xla_hlo_test_generate_hlo_graph = false;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_hlo_test_generate_hlo_graph",
+ &flags->xla_hlo_test_generate_hlo_graph,
+ "Generate graph output of HLO instructions"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's hlo_test_base
+// module.
+void AppendHloTestBaseFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the HloTestBaseFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloTestBaseFlags* GetHloTestBaseFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h
new file mode 100644
index 0000000000..23b808cecb
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
+
+// Legacy flags for XLA's hlo_test_base module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's hlo_test_base
+// module.
+void AppendHloTestBaseFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's hlo_test_base module.
+typedef struct {
+ bool xla_hlo_test_generate_hlo_graph; // Generate graph output of HLO
+ // instructions
+} HloTestBaseFlags;
+
+// Return a pointer to the HloTestBaseFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+HloTestBaseFlags* GetHloTestBaseFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_HLO_TEST_BASE_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc
new file mode 100644
index 0000000000..4242b501d4
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's layout_util module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the string value of the xla_default_layout flag and the flag
+// descriptor, initialized via raw_flags_init.
+static string* raw_flag;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag raw_flags_init;
+
+// Allocate *raw_flag. Called via call_once(&raw_flags_init,...).
+static void AllocateRawFlag() {
+ raw_flag = new string;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_default_layout", raw_flag,
+ "Default layout for Shapes in XLA. Valid values are: "
+ "'minor2major', 'major2minor', 'random', 'random:<seed>'. "
+ "For debugging purposes. If no seed (or 0) is given, a seed from "
+ "random_device is used."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Parse text into *layout.
+static bool ParseDefaultLayout(const string& text, DefaultLayout* layout) {
+ bool result = true;
+ std::vector<string> field = tensorflow::str_util::Split(text, ':');
+ if (field.size() > 0) {
+ if (field[0] == "random") {
+ layout->dimension_order = DefaultLayout::DimensionOrder::kRandom;
+ if (field.size() > 1) {
+ uint64 seed = 0;
+ result = tensorflow::strings::safe_strtou64(field[1], &seed);
+ layout->seed = seed;
+ }
+ } else if (field[0] == "minor2major") {
+ layout->dimension_order = DefaultLayout::DimensionOrder::kMinorToMajor;
+ } else if (field[0] == "major2minor") {
+ layout->dimension_order = DefaultLayout::DimensionOrder::kMajorToMinor;
+ } else {
+ result = false;
+ }
+ }
+ return result;
+}
+
+// Pointer to the parsed value of the flags, initialized via flags_init.
+static LayoutUtilFlags* flags;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ std::call_once(raw_flags_init, &AllocateRawFlag);
+ flags = new LayoutUtilFlags;
+ flags->xla_default_layout.dimension_order =
+ DefaultLayout::DimensionOrder::kMajorToMinor;
+ flags->xla_default_layout.seed = 0;
+ if (!ParseDefaultLayout(*raw_flag, &flags->xla_default_layout)) {
+ flags = nullptr;
+ }
+}
+
+// Append to *append_to the flag definitions associated with XLA's layout_util
+// module.
+void AppendLayoutUtilFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(raw_flags_init, &AllocateRawFlag);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the LayoutUtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+LayoutUtilFlags* GetLayoutUtilFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h
new file mode 100644
index 0000000000..177f428b73
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/layout_util_flags.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_
+
+// Legacy flags for the XLA's layout_util module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// The default layout for all newly created shapes. Specified by the flag
+// --xla_default_layout.
+struct DefaultLayout {
+ enum class DimensionOrder {
+ kRandom,
+ kMinorToMajor,
+ kMajorToMinor,
+ };
+
+ DimensionOrder dimension_order;
+ size_t seed;
+};
+
+// Append to *flag_list the flag definitions associated with XLA's layout_util
+// module.
+void AppendLayoutUtilFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's layout_util module.
+typedef struct {
+ // Default layout for Shapes in XLA. Valid values are: 'minor2major',
+ // 'major2minor', 'random', 'random:<seed>'. For debugging purposes. If no
+ // seed (or 0) is given, a seed from random_device is used.
+ DefaultLayout xla_default_layout;
+} LayoutUtilFlags;
+
+// Return a pointer to the LayoutUtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+LayoutUtilFlags* GetLayoutUtilFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LAYOUT_UTIL_FLAGS_H_
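Not part of this patch: the values accepted by --xla_default_layout and the DefaultLayout they produce, as implemented by ParseDefaultLayout in the .cc above; the shell lines are only illustrative:

  TF_XLA_FLAGS='--xla_default_layout=major2minor'  -> kMajorToMinor (the default)
  TF_XLA_FLAGS='--xla_default_layout=minor2major'  -> kMinorToMajor
  TF_XLA_FLAGS='--xla_default_layout=random'       -> kRandom, seed left 0 (a random_device seed is used downstream, per the flag description)
  TF_XLA_FLAGS='--xla_default_layout=random:42'    -> kRandom, seed 42

After parsing, GetLayoutUtilFlags() exposes the result via xla_default_layout.dimension_order and .seed; note it returns nullptr when the value fails to parse, since AllocateFlags clears the pointer on a parse error.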
diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc
new file mode 100644
index 0000000000..c8a71b284f
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.cc
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags associated with XLA's use of LLVM for code generation.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static LlvmBackendFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new LlvmBackendFlags;
+ flags->xla_fast_math = true;
+ flags->xla_precision_losing_optimizations = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_precision_losing_optimizations",
+ &flags->xla_precision_losing_optimizations,
+ "Allows llvm to make transformations that reduce the precision of "
+ "floating-point computations. This is equivalent to clang's "
+ "-funsafe-math-optimizations flag."),
+ tensorflow::Flag(
+ "xla_fast_math", &flags->xla_fast_math,
+ "Allows llvm to make all manner of unsafe floating-point "
+ "optimizations, including assuming that NaN and Inf don't appear. "
+ "This is equivalent to clang's -ffast-math flag."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+void AppendLlvmBackendFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+LlvmBackendFlags* GetLlvmBackendFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h
new file mode 100644
index 0000000000..e8c0489285
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's use of LLVM for
+// code generation.
+void AppendLlvmBackendFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's use of LLVM for code generation.
+typedef struct {
+ // Allows llvm to make transformations that reduce the precision of
+ // floating-point computations, but it *does not* allow it to disregard signed
+ // zero or assume that NaN and Inf never appear.
+ //
+ // Controls the "UnsafeFPMath" LLVM target option and
+ // llvm::FastMathFlags::allowReciprocal. This is equivalent to clang's
+ // -funsafe-math-optimizations flag.
+ bool xla_precision_losing_optimizations;
+
+ // Unleashes the full power of LLVM's unsafe floating-point optimizations.
+ // Everything is fair game, including disregarding signed zero and assuming
+ // that NaN and Inf never appear.
+ //
+ // This implies xla_precision_losing_optimizations, and is equivalent to
+ // clang's -ffast-math flag.
+ bool xla_fast_math;
+} LlvmBackendFlags;
+
+// Return a pointer to the LlvmBackendFlags struct. Repeated calls return the
+// same pointer. This should be called only after Flags::Parse() has returned.
+LlvmBackendFlags* GetLlvmBackendFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_BACKEND_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc
new file mode 100644
index 0000000000..3c53729a67
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's llvm_util module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static LlvmUtilFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new LlvmUtilFlags;
+ flags->xla_emit_tbaa = true;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa,
+ "Perform type-based alias analysis optimizations for "
+ "LLVM-based backends."),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's llvm_util
+// module.
+void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the LlvmUtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+LlvmUtilFlags* GetLlvmUtilFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h
new file mode 100644
index 0000000000..98da26b4b8
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
+
+// Legacy flags for XLA's llvm_util module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's llvm_util module.
+void AppendLlvmUtilFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's llvm_util module.
+typedef struct {
+ bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for
+ // LLVM-based backends.
+} LlvmUtilFlags;
+
+// Return a pointer to the LlvmUtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+LlvmUtilFlags* GetLlvmUtilFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc
new file mode 100644
index 0000000000..2a4e49b05a
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.cc
@@ -0,0 +1,206 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This module exports ParseFlagsFromEnv(), which allows other modules to parse
+// flags from an environment variable, or a file named by the environment
+// variable.
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+static const char kEnvVar[] = "TF_XLA_FLAGS"; // environment variable queried
+static const char kWS[] = " \t\r\n"; // whitespace
+
+// The following struct represents an argv[]-style array, parsed
+// from data gleaned from the environment.
+//
+// As usual, an anonymous namespace is advisable to avoid
+// constructor/destructor collisions with other "private" types
+// in the same named namespace.
+namespace {
+struct EnvArgv {
+ EnvArgv() : initialized(false), argc(0) {}
+ bool initialized; // whether the other fields have been set.
+ int argc; // elements used in argv[]
+ std::vector<char*> argv; // flag arguments parsed from environment string.
+ std::vector<char*> argv_save; // saved values from argv[] to avoid leaks
+};
+} // anonymous namespace
+
+// Append the string s0[0, .., s0len-1] concatenated with s1[0, .., s1len-1] as
+// a newly allocated nul-terminated string to the array *a. If s0==nullptr, a
+// nullptr is appended without increasing a->argc.
+static void AppendToEnvArgv(const char* s0, size_t s0len, const char* s1,
+ size_t s1len, EnvArgv* a) {
+ if (s0 == nullptr) {
+ a->argv.push_back(nullptr);
+ a->argv_save.push_back(nullptr);
+ } else {
+ string s = string(s0, s0len) + string(s1, s1len);
+ char* str = strdup(s.c_str());
+ a->argv.push_back(str);
+ a->argv_save.push_back(str);
+ a->argc++;
+ }
+}
+
+// Like s.find_first_of(x, pos), but return s.size() when find_first_of() would
+// return string::npos. This avoids if-statements elsewhere.
+static size_t FindFirstOf(const string& s, const char* x, size_t pos) {
+ size_t result = s.find_first_of(x, pos);
+ return result == string::npos ? s.size() : result;
+}
+
+// Like s.find_first_not_of(x, pos), but return s.size() when
+// find_first_not_of() would return string::npos. This avoids if-statements
+// elsewhere.
+static size_t FindFirstNotOf(const string& s, const char* x, size_t pos) {
+ size_t result = s.find_first_not_of(x, pos);
+ return result == string::npos ? s.size() : result;
+}
+
+// Given a string containing flags, parse them into the XLA command line flags.
+// The parse is best effort, and gives up on the first syntax error.
+static void ParseArgvFromString(const string& flag_str, EnvArgv* a) {
+ size_t b = FindFirstNotOf(flag_str, kWS, 0);
+ while (b != flag_str.size() && flag_str[b] == '-') {
+ // b is the index of the start of a flag.
+ // Set e to the index just past the end of the flag.
+ size_t e = b;
+ while (e != flag_str.size() && isascii(flag_str[e]) &&
+ (strchr("-_", flag_str[e]) != nullptr || isalnum(flag_str[e]))) {
+ e++;
+ }
+ if (e != flag_str.size() && flag_str[e] == '=' &&
+ e + 1 != flag_str.size() && strchr("'\"", flag_str[e + 1]) != nullptr) {
+ // A flag of the form --flag="something in double or single quotes"
+ int c;
+ e++; // point just past '='
+ size_t eflag = e;
+ char quote = flag_str[e];
+ e++; // point just past quote
+ // Put in value the string with quotes removed.
+ string value;
+ for (; e != flag_str.size() && (c = flag_str[e]) != quote; e++) {
+ if (quote == '"' && c == '\\' && e + 1 != flag_str.size()) {
+ // Handle backslash in double quoted strings. They are literal in
+ // single-quoted strings.
+ e++;
+ c = flag_str[e];
+ }
+ value += c;
+ }
+ if (e != flag_str.size()) { // skip final " or '
+ e++;
+ }
+ AppendToEnvArgv(flag_str.data() + b, eflag - b, value.data(),
+ value.size(), a);
+ } else { // A flag without a quoted value.
+ e = FindFirstOf(flag_str, kWS, e);
+ AppendToEnvArgv(flag_str.data() + b, e - b, "", 0, a);
+ }
+ b = FindFirstNotOf(flag_str, kWS, e);
+ }
+}
+
+// Call ParseArgvFromString(..., a) on a string derived from the setting of an
+// environment variable kEnvVar, or a file it points to.
+static void SetArgvFromEnv(EnvArgv* a) {
+ if (!a->initialized) {
+ static const char kDummyArgv[] = "<argv[0]>";
+ AppendToEnvArgv(kDummyArgv, strlen(kDummyArgv), nullptr, 0,
+ a); // dummy argv[0]
+ const char* env = getenv(kEnvVar);
+ if (env == nullptr || env[0] == '\0') {
+ // nothing
+ } else if (env[strspn(env, kWS)] == '-') { // flags in env var value
+ ParseArgvFromString(env, a);
+ } else { // assume it's a file name
+ FILE* fp = fopen(env, "r");
+ if (fp != nullptr) {
+ string str;
+ char buf[512];
+ int n;
+ while ((n = fread(buf, 1, sizeof(buf), fp)) > 0) {
+ str.append(buf, n);
+ }
+ fclose(fp);
+ ParseArgvFromString(str, a);
+ }
+ }
+ AppendToEnvArgv(nullptr, 0, nullptr, 0, a); // add trailing nullptr to *a.
+ a->initialized = true;
+ }
+}
+
+// The simulated argv[] parsed from the environment.
+static EnvArgv* env_argv;
+
+// Used to protect accesses to env_argv.
+static tensorflow::mutex env_argv_mu(tensorflow::LINKER_INITIALIZED);
+
+// Call Flags::Parse(argc, argv, flag_list) against any as yet unrecognized
+// flags passed in from the environment.
+bool ParseFlagsFromEnv(const std::vector<tensorflow::Flag>& flag_list) {
+ env_argv_mu.lock();
+ if (env_argv == nullptr) {
+ env_argv = new EnvArgv;
+ }
+ SetArgvFromEnv(env_argv); // a no-op if already initialized
+ bool result =
+ tensorflow::Flags::Parse(&env_argv->argc, &env_argv->argv[0], flag_list);
+ env_argv_mu.unlock();
+ return result;
+}
+
+// Testing only.
+// Reset the env_argv struct so that subsequent calls to ParseFlagsFromEnv()
+// will parse the environment variable (or the file it points to) anew, and set
+// *pargc, and *pargv to point to the internal locations of the argc and argv
+// constructed from the environment.
+void ResetFlagsFromEnvForTesting(int** pargc, std::vector<char*>** pargv) {
+ env_argv_mu.lock();
+ if (env_argv == nullptr) {
+ env_argv = new EnvArgv;
+ }
+ if (!env_argv->argv_save.empty()) {
+ for (int i = 0; env_argv->argv_save[i] != nullptr; i++) {
+ free(env_argv->argv_save[i]);
+ }
+ }
+ env_argv->initialized = false;
+ env_argv->argc = 0;
+ env_argv->argv.clear();
+ env_argv->argv_save.clear();
+ env_argv_mu.unlock();
+ *pargc = &env_argv->argc;
+ *pargv = &env_argv->argv;
+}
+
+} // namespace legacy_flags
+} // namespace xla
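Not part of this patch: a worked example of how ParseArgvFromString above splits an environment string into the simulated argv; the flag names are just ones defined elsewhere in this change:

  TF_XLA_FLAGS: --ftz --gpu_architecture=compute_52 --llvm_cl_opts='-a -b'

  argv[0]: "<argv[0]>"                      (dummy inserted by SetArgvFromEnv)
  argv[1]: "--ftz"                          (bare boolean flag)
  argv[2]: "--gpu_architecture=compute_52"  (unquoted value, ends at whitespace)
  argv[3]: "--llvm_cl_opts=-a -b"           (single quotes stripped, space kept)

A trailing nullptr is appended after the last entry, and ParseFlagsFromEnv later runs tensorflow::Flags::Parse over this array against each module's flag list.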
diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h
new file mode 100644
index 0000000000..b54482ad2b
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
+
+// This module exports ParseFlagsFromEnv(), which allows other modules to parse
+// flags from the environment variable TF_XLA_FLAGS, or (if the first
+// non-whitespace in the variable value is not '-'), a file named by that
+// environment variable. The accepted syntax is that flags arguments are of
+// the form --flag=value or (for boolean flags) --flag, and are whitespace
+// separated. The <value> may be one of:
+// - <non-whitespace, non-nul not starting with single-quote or double-quote>
+// in which case the effective value is the string itself
+// - <single-quote><character string not containing nul or
+// single-quote><single-quote> in which case the effective value is the
+// string with the single-quotes removed
+// - <double-quote><character string not containing nul or unescaped
+// double-quote><double-quote> in which case the effective value is the
+// string with the double-quotes removed, and escaped sequences of
+// <backslash><char> replaced by <char>.
+//
+// Flag values inconsistent with the type of the flag will be rejected by the
+// flag parser.
+//
+// Examples:
+// TF_XLA_FLAGS="--foo=bar --wombat='value with a space'"
+//
+// TF_XLA_FLAGS=/tmp/flagfile
+// where /tmp/flagfile might contain
+// --some_flag="This is a string containing a \" and a '."
+// --another_flag=wombats
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Call tensorflow::Flags::Parse(argc, argv, flag_list) against any as yet
+// unrecognized flags passed in from the environment, and return its
+// return value.
+bool ParseFlagsFromEnv(const std::vector<tensorflow::Flag>& flag_list);
+
+// Used only for testing. Not to be used by clients.
+void ResetFlagsFromEnvForTesting(int** pargc, std::vector<char*>** pargv);
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_PARSE_FLAGS_FROM_ENV_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
new file mode 100644
index 0000000000..7a966ce241
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/parse_flags_from_env_test.cc
@@ -0,0 +1,190 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Test for parse_flags_from_env.cc
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Test that XLA flags can be set from the environment.
+// Failure messages are accompanied by the text in msg[].
+static void TestParseFlagsFromEnv(const char* msg) {
+ // Initialize module under test.
+ int* pargc;
+ std::vector<char*>* pargv;
+ ResetFlagsFromEnvForTesting(&pargc, &pargv);
+
+ // Ensure that environment variable can be parsed when
+ // no flags are expected.
+ std::vector<tensorflow::Flag> empty_flag_list;
+ bool parsed_ok = ParseFlagsFromEnv(empty_flag_list);
+ CHECK(parsed_ok) << msg;
+ const std::vector<char*>& argv_first = *pargv;
+ CHECK_NE(argv_first[0], nullptr) << msg;
+ int i = 0;
+ while (argv_first[i] != nullptr) {
+ i++;
+ }
+ CHECK_EQ(i, *pargc) << msg;
+
+ // Check that actual flags can be parsed.
+ bool simple = false;
+ string with_value;
+ string embedded_quotes;
+ string single_quoted;
+ string double_quoted;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("simple", &simple, ""),
+ tensorflow::Flag("with_value", &with_value, ""),
+ tensorflow::Flag("embedded_quotes", &embedded_quotes, ""),
+ tensorflow::Flag("single_quoted", &single_quoted, ""),
+ tensorflow::Flag("double_quoted", &double_quoted, ""),
+ };
+ parsed_ok = ParseFlagsFromEnv(flag_list);
+ CHECK_EQ(*pargc, 1) << msg;
+ const std::vector<char*>& argv_second = *pargv;
+ CHECK_NE(argv_second[0], nullptr) << msg;
+ CHECK_EQ(argv_second[1], nullptr) << msg;
+ CHECK(parsed_ok) << msg;
+ CHECK(simple) << msg;
+ CHECK_EQ(with_value, "a_value") << msg;
+ CHECK_EQ(embedded_quotes, "single'double\"") << msg;
+ CHECK_EQ(single_quoted, "single quoted \\\\ \n \"") << msg;
+ CHECK_EQ(double_quoted, "double quoted \\ \n '\"") << msg;
+}
+
+// The flags settings to test.
+static const char kTestFlagString[] =
+ "--simple "
+ "--with_value=a_value "
+ "--embedded_quotes=single'double\" "
+ "--single_quoted='single quoted \\\\ \n \"' "
+ "--double_quoted=\"double quoted \\\\ \n '\\\"\" ";
+
+// Test that the environment variable is parsed correctly.
+TEST(ParseFlagsFromEnv, Basic) {
+ // Prepare environment.
+ setenv("TF_XLA_FLAGS", kTestFlagString, true /*overwrite*/);
+ TestParseFlagsFromEnv("(flags in environment variable)");
+}
+
+// Test that a file named by the environment variable is parsed correctly.
+TEST(ParseFlagsFromEnv, File) {
+ // environment variables where tmp dir may be specified.
+ static const char* kTempVars[] = {"TEST_TMPDIR", "TMP"};
+ static const char kTempDir[] = "/tmp"; // default temp dir if all else fails.
+ const char* tmp_dir = nullptr;
+ for (int i = 0; i != TF_ARRAYSIZE(kTempVars) && tmp_dir == nullptr; i++) {
+ tmp_dir = getenv(kTempVars[i]);
+ }
+ if (tmp_dir == nullptr) {
+ tmp_dir = kTempDir;
+ }
+ string tmp_file = tensorflow::strings::Printf("%s/parse_flags_from_env.%d",
+ tmp_dir, getpid());
+ FILE* fp = fopen(tmp_file.c_str(), "w");
+ CHECK_NE(fp, nullptr) << "can't write to " << tmp_file;
+ for (int i = 0; kTestFlagString[i] != '\0'; i++) {
+ putc(kTestFlagString[i], fp);
+ }
+ fflush(fp);
+ CHECK_EQ(ferror(fp), 0) << "writes failed to " << tmp_file;
+ fclose(fp);
+ // Prepare environment.
+ setenv("TF_XLA_FLAGS", tmp_file.c_str(), true /*overwrite*/);
+ TestParseFlagsFromEnv("(flags in file)");
+ unlink(tmp_file.c_str());
+}
+
+// Name of the test binary.
+static const char* binary_name;
+
+// Test that when we use both the environment variable and actual
+// command line flags (when the latter are possible), the latter win.
+TEST(ParseFlagsFromEnv, EnvAndFlag) {
+ // TODO(m3b): convert to Subprocess when CL 137771604 is finished.
+ static struct {
+ const char* env;
+ const char* arg;
+ const char* expected_value;
+ } test[] = {
+ {nullptr, nullptr, "1\n"},
+ {nullptr, "--int_flag=2", "2\n"},
+ {"--int_flag=3", nullptr, "3\n"},
+ {"--int_flag=3", "--int_flag=2", "2\n"}, // flag beats environment
+ };
+ for (int i = 0; i != TF_ARRAYSIZE(test); i++) {
+ if (test[i].env != nullptr) {
+ setenv("TF_XLA_FLAGS", test[i].env, true /*overwrite*/);
+ }
+ tensorflow::SubProcess child;
+ std::vector<string> argv;
+ argv.push_back(binary_name);
+ argv.push_back("--recursing");
+ if (test[i].arg != nullptr) {
+ argv.push_back(test[i].arg);
+ }
+ child.SetProgram(binary_name, argv);
+ child.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE);
+ CHECK(child.Start()) << "test " << i;
+ string stdout_str;
+ int child_status = child.Communicate(nullptr, &stdout_str, nullptr);
+ CHECK_EQ(child_status, 0) << "test " << i;
+ CHECK_EQ(stdout_str, test[i].expected_value) << "test " << i;
+ }
+}
+
+} // namespace legacy_flags
+} // namespace xla
+
+int main(int argc, char* argv[]) {
+ // Save name of binary so that it may invoke itself.
+ xla::legacy_flags::binary_name = argv[0];
+ bool recursing = false;
+ xla::int32 int_flag = 1;
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("recursing", &recursing,
+ "Whether the binary is being invoked recusively."),
+ tensorflow::Flag("int_flag", &int_flag, "An integer flag to test with"),
+ };
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = xla::legacy_flags::ParseFlagsFromEnv(flag_list);
+ if (!parse_ok) {
+ LOG(QFATAL) << "can't parse from environment\n" << usage;
+ }
+ parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_ok) {
+ LOG(QFATAL) << usage;
+ }
+ if (recursing) {
+ printf("%d\n", int_flag);
+ exit(0);
+ }
+ testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.cc b/tensorflow/compiler/xla/legacy_flags/service_flags.cc
new file mode 100644
index 0000000000..41cb8d8bdf
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/service_flags.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's service module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static ServiceFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new ServiceFlags;
+ flags->xla_hlo_profile = false;
+ flags->xla_log_hlo_text = "";
+ flags->xla_generate_hlo_graph = "";
+ flags->xla_hlo_graph_addresses = false;
+ flags->xla_hlo_graph_layout = false;
+ flags->xla_hlo_graph_for_compute_constant = false;
+ flags->xla_dump_computations_to = "";
+ flags->xla_dump_hlo_text_to = "";
+ flags->xla_dump_executions_to = "";
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag(
+ "xla_hlo_profile", &flags->xla_hlo_profile,
+ "Instrument the computation to collect per-HLO cycle counts"),
+ tensorflow::Flag(
+ "xla_log_hlo_text", &flags->xla_log_hlo_text,
+ "If non-empty, print the text format of "
+ "HLO modules whose name partially matches this regex. E.g. "
+ "xla_log_hlo_text=.* will dump the text for every module."),
+ tensorflow::Flag(
+ "xla_generate_hlo_graph", &flags->xla_generate_hlo_graph,
+ "If non-empty, dump graph of HLO modules whose name partially "
+ "matches this regex. E.g. --xla_generate_hlo_graph=.* will dump "
+ "the graph of every module."),
+ tensorflow::Flag("xla_hlo_graph_addresses",
+ &flags->xla_hlo_graph_addresses,
+ "Show addresses of HLO ops in graph"),
+ tensorflow::Flag("xla_hlo_graph_layout", &flags->xla_hlo_graph_layout,
+ "Show layout of HLO ops in graph"),
+ tensorflow::Flag(
+ "xla_hlo_graph_for_compute_constant",
+ &flags->xla_hlo_graph_for_compute_constant,
+          "If true, include hlo dumps of graphs from ComputeConstant. "
+          "Such graphs still need to be matched via xla_generate_hlo_graph."),
+ tensorflow::Flag("xla_dump_computations_to",
+ &flags->xla_dump_computations_to,
+ "Dumps computations that XLA executes into the provided "
+ "directory path"),
+ tensorflow::Flag("xla_dump_hlo_text_to", &flags->xla_dump_hlo_text_to,
+ "Dumps HLO modules that XLA executes into the provided "
+ "directory path"),
+ tensorflow::Flag("xla_dump_executions_to", &flags->xla_dump_executions_to,
+ "Dumps parameters and results of computations that XLA "
+ "executes into the provided directory path"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's service module.
+void AppendServiceFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the ServiceFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ServiceFlags* GetServiceFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
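A minimal usage sketch of the pattern above, assuming a hypothetical tool that links against this library: flag definitions are appended before command-line parsing, and the parsed values are read afterwards through GetServiceFlags(); values may also arrive via the TF_XLA_FLAGS environment variable, which AllocateFlags consumes through ParseFlagsFromEnv.

#include <vector>

#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
#include "tensorflow/core/util/command_line_flags.h"

int main(int argc, char* argv[]) {
  std::vector<tensorflow::Flag> flag_list;
  // Register the service flag definitions (also triggers env-var parsing).
  xla::legacy_flags::AppendServiceFlags(&flag_list);
  // Parse the command line; flags given here override environment values.
  tensorflow::Flags::Parse(&argc, argv, flag_list);
  // Read the parsed values only after parsing has completed.
  xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags();
  if (flags->xla_hlo_profile) {
    // ... enable per-HLO profiling ...
  }
  return 0;
}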
diff --git a/tensorflow/compiler/xla/legacy_flags/service_flags.h b/tensorflow/compiler/xla/legacy_flags/service_flags.h
new file mode 100644
index 0000000000..d982506944
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/service_flags.h
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_
+
+// Legacy flags for XLA's service module.
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's service module.
+void AppendServiceFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's service module.
+typedef struct {
+ bool xla_hlo_profile; // Instrument the computation to collect per-HLO cycle
+ // counts
+ string xla_log_hlo_text; // If non-empty, print the text format of the HLO
+ // modules whose name partially
+ // matches this regex. E.g. xla_log_hlo_text=.*
+ // will dump the text for every module.
+ string xla_generate_hlo_graph; // If non-empty, dump graph of HLO modules
+ // whose name partially matches this regex.
+ // E.g. --xla_generate_hlo_graph=.* will dump
+ // the graph of every module.
+ bool xla_hlo_graph_addresses; // Show addresses of HLO ops in graph
+ bool xla_hlo_graph_layout; // Show layout of HLO ops in graph
+ bool xla_hlo_graph_for_compute_constant; // If true, include hlo dumps of
+ // graphs from ComputeConstant.
+ // Such graphs still need to be
+ // matched via
+ // xla_generate_hlo_graph.
+  string xla_dump_hlo_text_to;      // Dumps HLO text for each HLO module that
+                                    // is executed into the provided directory
+                                    // path
+  string xla_dump_computations_to;  // Dumps computations that XLA executes
+                                    // into the provided directory path
+  string xla_dump_executions_to;    // Dumps parameters and results of
+                                    // computations that XLA executes into
+                                    // the provided directory path
+} ServiceFlags;
+
+// Return a pointer to the ServiceFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+ServiceFlags* GetServiceFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_SERVICE_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc
new file mode 100644
index 0000000000..6506175777
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's stream_assignment module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static StreamAssignmentFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new StreamAssignmentFlags;
+ flags->xla_gpu_disable_multi_streaming = false;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_gpu_disable_multi_streaming",
+ &flags->xla_gpu_disable_multi_streaming,
+ "Disable multi-streaming in XLA's GPU backend"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's stream_assignment
+// module.
+void AppendStreamAssignmentFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the StreamAssignmentFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+StreamAssignmentFlags* GetStreamAssignmentFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h
new file mode 100644
index 0000000000..a98f9b3458
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_
+
+// Legacy flags for XLA's stream_assignment module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's stream_assignment
+// module.
+void AppendStreamAssignmentFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's stream_assignment module.
+typedef struct {
+ bool xla_gpu_disable_multi_streaming; // Disable multi-streaming in XLA's GPU
+ // backend
+} StreamAssignmentFlags;
+
+// Return a pointer to the StreamAssignmentFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+StreamAssignmentFlags* GetStreamAssignmentFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_STREAM_ASSIGNMENT_FLAGS_H_
diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.cc b/tensorflow/compiler/xla/legacy_flags/util_flags.cc
new file mode 100644
index 0000000000..e6df19ddd2
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/util_flags.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Legacy flags for XLA's util module.
+
+#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
+#include <vector>
+
+#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
+#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Pointers to the parsed value of the flags and flag descriptors, initialized
+// via flags_init.
+static UtilFlags* flags;
+static std::vector<tensorflow::Flag>* flag_list;
+static std::once_flag flags_init;
+
+// Allocate *flags. Called via call_once(&flags_init,...).
+static void AllocateFlags() {
+ flags = new UtilFlags;
+ flags->xla_status_add_backtrace = false;
+ flag_list = new std::vector<tensorflow::Flag>({
+ tensorflow::Flag("xla_status_add_backtrace",
+ &flags->xla_status_add_backtrace,
+ "add backtraces to XLA-produced status values"),
+ });
+ ParseFlagsFromEnv(*flag_list);
+}
+
+// Append to *append_to flag definitions associated with XLA's util module.
+void AppendUtilFlags(std::vector<tensorflow::Flag>* append_to) {
+ std::call_once(flags_init, &AllocateFlags);
+ append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
+}
+
+// Return a pointer to the UtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+UtilFlags* GetUtilFlags() {
+ std::call_once(flags_init, &AllocateFlags);
+ return flags;
+}
+
+} // namespace legacy_flags
+} // namespace xla
diff --git a/tensorflow/compiler/xla/legacy_flags/util_flags.h b/tensorflow/compiler/xla/legacy_flags/util_flags.h
new file mode 100644
index 0000000000..03bffcd726
--- /dev/null
+++ b/tensorflow/compiler/xla/legacy_flags/util_flags.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_
+#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_
+
+// Legacy flags for XLA's util module.
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace legacy_flags {
+
+// Append to *flag_list flag definitions associated with XLA's util module.
+void AppendUtilFlags(std::vector<tensorflow::Flag>* flag_list);
+
+// The values of flags associated with XLA's util module.
+typedef struct {
+ bool xla_status_add_backtrace; // add backtraces to XLA-produced statuses
+} UtilFlags;
+
+// Return a pointer to the UtilFlags struct;
+// repeated calls return the same pointer.
+// This should be called only after Flags::Parse() has returned.
+UtilFlags* GetUtilFlags();
+
+} // namespace legacy_flags
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_UTIL_FLAGS_H_
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
new file mode 100644
index 0000000000..f2b2bf8cec
--- /dev/null
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -0,0 +1,989 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_util.h"
+
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+/* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case U8:
+ return *LiteralUtil::CreateR0<uint8>(0);
+ case U32:
+ return *LiteralUtil::CreateR0<uint32>(0);
+ case U64:
+ return *LiteralUtil::CreateR0<uint64>(0);
+ case S8:
+ return *LiteralUtil::CreateR0<int8>(0);
+ case S32:
+ return *LiteralUtil::CreateR0<int32>(0);
+ case S64:
+ return *LiteralUtil::CreateR0<int64>(0);
+ case F32:
+ return *LiteralUtil::CreateR0<float>(0);
+ case F64:
+ return *LiteralUtil::CreateR0<double>(0);
+ case PRED:
+ return *LiteralUtil::CreateR0<bool>(false);
+ case S16:
+ case U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case TUPLE:
+ LOG(FATAL) << "tuple element type cannot take on value of 0";
+ case OPAQUE:
+ LOG(FATAL) << "opaque element type cannot take on value of 0";
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << primitive_type;
+ }
+}
+
+/* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case U8:
+ return *LiteralUtil::CreateR0<uint8>(1);
+ case U32:
+ return *LiteralUtil::CreateR0<uint32>(1);
+ case U64:
+ return *LiteralUtil::CreateR0<uint64>(1);
+ case S8:
+ return *LiteralUtil::CreateR0<int8>(1);
+ case S32:
+ return *LiteralUtil::CreateR0<int32>(1);
+ case S64:
+ return *LiteralUtil::CreateR0<int64>(1);
+ case F32:
+ return *LiteralUtil::CreateR0<float>(1);
+ case F64:
+ return *LiteralUtil::CreateR0<double>(1);
+ case PRED:
+ return *LiteralUtil::CreateR0<bool>(true);
+ case S16:
+ case U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case TUPLE:
+ LOG(FATAL) << "tuple element type cannot take on value of 1";
+ case OPAQUE:
+ LOG(FATAL) << "opaque element type cannot take on value of 1";
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << primitive_type;
+ }
+}
+
+/* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case U8:
+ return *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
+ case U32:
+ return *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
+ case U64:
+ return *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
+ case S8:
+ return *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
+ case S32:
+ return *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
+ case S64:
+ return *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
+ case F32:
+ return *LiteralUtil::CreateR0<float>(
+ -std::numeric_limits<float>::infinity());
+ case F64:
+ return *LiteralUtil::CreateR0<double>(
+ -std::numeric_limits<double>::infinity());
+ case PRED:
+ return *LiteralUtil::CreateR0<bool>(false);
+ case S16:
+ case U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case TUPLE:
+ LOG(FATAL) << "tuple element type has no minimum value";
+ case OPAQUE:
+ LOG(FATAL) << "opaque element type has no minimum value";
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << primitive_type;
+ }
+}
+
+/* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case U8:
+ return *LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
+ case U32:
+ return *LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
+ case U64:
+ return *LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
+ case S8:
+ return *LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
+ case S32:
+ return *LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
+ case S64:
+ return *LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
+ case F32:
+ return *LiteralUtil::CreateR0<float>(
+ std::numeric_limits<float>::infinity());
+ case F64:
+ return *LiteralUtil::CreateR0<double>(
+ std::numeric_limits<double>::infinity());
+ case PRED:
+ return *LiteralUtil::CreateR0<bool>(true);
+ case S16:
+ case U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case TUPLE:
+ LOG(FATAL) << "tuple element type has no maximum value";
+ case OPAQUE:
+ LOG(FATAL) << "opaque element type has no maximum value";
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << primitive_type;
+ }
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+ const tensorflow::core::Bitmap& values) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR1(values, literal.get());
+ return literal;
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1U8(
+ tensorflow::StringPiece value) {
+ auto literal = MakeUnique<Literal>();
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())});
+ literal->set_u8s(value.ToString());
+ return literal;
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2F32Linspace(
+ float from, float to, int64 rows, int64 cols) {
+ auto value = MakeLinspaceArray2D(from, to, rows, cols);
+ return CreateR2FromArray2D(*value);
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::Relayout(
+ const Literal& original, const Layout& layout) {
+  // Note: if this were a performance bottleneck, we could avoid cloning and
+  // just make an uninitialized array instead, since all values are clobbered
+  // below.
+ std::unique_ptr<Literal> result = CloneToUnique(original);
+ *result->mutable_shape()->mutable_layout() = layout;
+ const PrimitiveType primitive_type = original.shape().element_type();
+ switch (primitive_type) {
+ case F32:
+ LiteralUtil::EachCell<float>(
+ original,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
+ LiteralUtil::Set<float>(result.get(), indices, value);
+ });
+ return result;
+ case S32:
+ LiteralUtil::EachCell<int32>(
+ original,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, int32 value) {
+ LiteralUtil::Set<int32>(result.get(), indices, value);
+ });
+ return result;
+ case U32:
+ LiteralUtil::EachCell<uint32>(
+ original,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 value) {
+ LiteralUtil::Set<uint32>(result.get(), indices, value);
+ });
+ return result;
+ default:
+ LOG(FATAL) << "not yet implemented: "
+ << PrimitiveType_Name(primitive_type);
+ }
+}
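As a rough sketch of Relayout (not part of this change), assuming the default layout for a freshly created R2 literal is dim0-major ({1, 0} minor-to-major): re-laying it out as {0, 1} reorders the underlying storage but leaves every logical cell unchanged.

auto original = xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f},
                                                   {3.0f, 4.0f}});
// Request the {0, 1} layout; the value at each logical index is preserved.
auto relaid = xla::LiteralUtil::Relayout(*original,
                                         xla::LayoutUtil::MakeLayout({0, 1}));
CHECK_EQ(xla::LiteralUtil::Get<float>(*relaid, {1, 0}), 3.0f);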
+
+/* static */ StatusOr<std::unique_ptr<Literal>> LiteralUtil::Reshape(
+ const xla::Literal& input, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ if (ShapeUtil::IsTuple(input.shape())) {
+ return InvalidArgument("Reshape does not support tuples.");
+ }
+
+ if (!LayoutUtil::IsMonotonicWithDim0Major(input.shape().layout())) {
+ return Unimplemented(
+ "Input shape must have a monotonic layout where dimension 0 is major, "
+ "was: %s",
+ LayoutUtil::HumanString(input.shape().layout()).c_str());
+ }
+ std::vector<int64> layout(dimensions.size());
+ std::iota(layout.rbegin(), layout.rend(), 0);
+
+ // Because the layout is monotonic, we can simply reuse the same sequence of
+ // values without changing their order.
+ std::unique_ptr<Literal> output = CloneToUnique(input);
+ output->clear_shape();
+ output->mutable_shape()->set_element_type(input.shape().element_type());
+ for (int64 dimension : dimensions) {
+ output->mutable_shape()->add_dimensions(dimension);
+ }
+ *output->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(layout);
+
+ int64 elements_before = ShapeUtil::ElementsIn(input.shape());
+ int64 elements_after = ShapeUtil::ElementsIn(output->shape());
+ if (elements_before != elements_after) {
+ return InvalidArgument(
+ "Shapes before and after LiteralUtil::Reshape have different numbers "
+ "of elements: %s vs %s.",
+ ShapeUtil::HumanString(input.shape()).c_str(),
+ ShapeUtil::HumanString(output->shape()).c_str());
+ }
+ return std::move(output);
+}
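A hedged sketch of Reshape, assuming the StatusOr type's ConsumeValueOrDie() accessor: because the input must be dim0-major, the linear element sequence is simply reinterpreted under the new dimensions.

auto input = xla::LiteralUtil::CreateR2<int32>({{1, 2, 3},
                                                {4, 5, 6}});
// 2x3 -> 3x2; the row-major reading order 1..6 is preserved.
auto reshaped = xla::LiteralUtil::Reshape(*input, {3, 2}).ConsumeValueOrDie();
CHECK_EQ(xla::LiteralUtil::Get<int32>(*reshaped, {2, 1}), 6);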
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::Transpose(
+ const Literal& original, tensorflow::gtl::ArraySlice<int64> permutation) {
+ CHECK(!ShapeUtil::IsTuple(original.shape()))
+ << "tuple is not supported for transpose";
+ std::vector<int64> dimension_numbers(ShapeUtil::Rank(original.shape()));
+ std::iota(dimension_numbers.begin(), dimension_numbers.end(), 0);
+ CHECK(std::is_permutation(permutation.begin(), permutation.end(),
+ dimension_numbers.begin()))
+ << "given permutation is not a permutation of dimension numbers";
+ std::vector<int64> new_dimension_sizes;
+ for (const int64 dim : permutation) {
+ new_dimension_sizes.push_back(original.shape().dimensions(dim));
+ }
+ const auto result_shape = ShapeUtil::MakeShape(
+ original.shape().element_type(), new_dimension_sizes);
+ std::unique_ptr<Literal> result = CloneToUnique(original);
+ *result->mutable_shape() = result_shape;
+ const PrimitiveType primitive_type = original.shape().element_type();
+ std::vector<int64> new_indices(ShapeUtil::Rank(original.shape()));
+ switch (primitive_type) {
+ case F32:
+ LiteralUtil::EachCell<float>(
+ original,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
+ for (int64 i = 0; i < permutation.size(); ++i) {
+ new_indices[i] = indices[permutation[i]];
+ }
+ LiteralUtil::Set<float>(result.get(), new_indices, value);
+ });
+ return result;
+ default:
+ LOG(FATAL) << "not yet implemented: "
+ << PrimitiveType_Name(primitive_type);
+ }
+}
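A small sketch of Transpose (mine, not from the patch), matching the comment semantics in literal_util.h: result dimension i takes its size and values from original dimension permutation[i]; note the implementation above currently handles only F32.

auto m = xla::LiteralUtil::CreateR2<float>({{1.0f, 2.0f, 3.0f},
                                            {4.0f, 5.0f, 6.0f}});
// Swap the two dimensions: the result has shape f32[3,2] and t[j][i] == m[i][j].
auto t = xla::LiteralUtil::Transpose(*m, {1, 0});
CHECK_EQ(xla::LiteralUtil::Get<float>(*t, {2, 0}), 3.0f);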
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::Slice(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) {
+  CHECK(!ShapeUtil::IsTuple(literal.shape()))
+      << "tuple is not supported for slice";
+
+ std::vector<int64> result_dimensions;
+ for (int64 dnum = 0; dnum < ShapeUtil::Rank(literal.shape()); ++dnum) {
+ CHECK_GE(start_indices[dnum], 0);
+ CHECK_LE(limit_indices[dnum], literal.shape().dimensions(dnum));
+ int64 dimension = limit_indices[dnum] - start_indices[dnum];
+ CHECK_GT(dimension, 0);
+ result_dimensions.push_back(dimension);
+ }
+ const auto result_shape = ShapeUtil::MakeShapeWithLayout(
+ literal.shape().element_type(), result_dimensions,
+ AsInt64Slice(literal.shape().layout().minor_to_major()));
+
+ auto result_literal = MakeUnique<Literal>();
+ *result_literal->mutable_shape() = result_shape;
+ Reserve(ShapeUtil::ElementsIn(result_shape), result_literal.get());
+
+ std::vector<int64> new_indices(ShapeUtil::Rank(result_shape));
+ switch (result_shape.element_type()) {
+ case F32:
+ LiteralUtil::EachCell<float>(
+ *result_literal,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, float /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ float value = LiteralUtil::Get<float>(literal, new_indices);
+ LiteralUtil::Set<float>(result_literal.get(), indices, value);
+ });
+ return result_literal;
+ case S32:
+ LiteralUtil::EachCell<int32>(
+ *result_literal,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, int32 /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ int32 value = LiteralUtil::Get<int32>(literal, new_indices);
+ LiteralUtil::Set<int32>(result_literal.get(), indices, value);
+ });
+ return result_literal;
+ case U32:
+ LiteralUtil::EachCell<uint32>(
+ *result_literal,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 /*value*/) {
+ for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) {
+ new_indices[i] = indices[i] + start_indices[i];
+ }
+ uint32 value = LiteralUtil::Get<uint32>(literal, new_indices);
+ LiteralUtil::Set<uint32>(result_literal.get(), indices, value);
+ });
+ return result_literal;
+ default:
+ LOG(FATAL) << "not yet implemented: "
+ << PrimitiveType_Name(result_shape.element_type());
+ }
+}
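A brief sketch of the half-open slicing semantics (not part of this change): each dimension keeps the indices in [start, limit), and the result reuses the input's layout.

auto a = xla::LiteralUtil::CreateR2<int32>({{0, 1, 2},
                                            {3, 4, 5},
                                            {6, 7, 8}});
// Keep rows [1, 3) and columns [0, 2): the result is s32[2,2] {{3, 4}, {6, 7}}.
auto s = xla::LiteralUtil::Slice(*a, {1, 0}, {3, 2});
CHECK_EQ(xla::LiteralUtil::Get<int32>(*s, {1, 1}), 7);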
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::CloneToUnique(
+ const Literal& literal) {
+ auto unique = MakeUnique<Literal>();
+ *unique = literal;
+ return unique;
+}
+
+/* static */ string LiteralUtil::GetAsString(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ switch (literal.shape().element_type()) {
+ case PRED:
+ return Get<bool>(literal, multi_index) ? "true" : "false";
+ case U8:
+ return tensorflow::strings::StrCat(Get<uint8>(literal, multi_index));
+ case S32:
+ return tensorflow::strings::StrCat(Get<int32>(literal, multi_index));
+ case S64:
+ return tensorflow::strings::StrCat(Get<int64>(literal, multi_index));
+ case U32:
+ return tensorflow::strings::StrCat(Get<uint32>(literal, multi_index));
+ case U64:
+ return tensorflow::strings::StrCat(Get<uint64>(literal, multi_index));
+ case F32:
+ return tensorflow::strings::StrCat(Get<float>(literal, multi_index));
+ case F64:
+ return tensorflow::strings::StrCat(Get<double>(literal, multi_index));
+ default:
+ return tensorflow::strings::StrCat(
+ "[", PrimitiveType_Name(literal.shape().element_type()), "]");
+ }
+}
+
+/* static */ int64 LiteralUtil::LinearIndex(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return IndexUtil::MultidimensionalIndexToLinearIndex(literal.shape(),
+ multi_index);
+}
+
+/* static */ string LiteralUtil::ToString(const Literal& literal) {
+ const Shape& shape = literal.shape();
+ std::vector<string> pieces;
+
+ auto element_to_string =
+ [&literal](tensorflow::gtl::ArraySlice<int64> indices) -> string {
+ PrimitiveType element_type = literal.shape().element_type();
+ if (element_type == PRED) {
+ // We display predicates in a densely packed form.
+ return Get<bool>(literal, indices) ? "1" : "0";
+ }
+ return ((!indices.empty() && indices.back() > 0) ? ", " : "") +
+ GetAsString(literal, indices);
+ };
+
+ // TODO(b/32894291): refactor this code to reduce code duplication.
+ if (ShapeUtil::IsTuple(shape)) {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" (\n");
+ for (const auto& element_literal : literal.tuple_literals()) {
+ pieces.push_back(ToString(element_literal));
+ pieces.push_back(",\n");
+ }
+ pieces.push_back(")");
+ } else if (ShapeUtil::Rank(shape) == 0) {
+ pieces.push_back(GetAsString(literal, {}));
+ } else if (ShapeUtil::Rank(shape) == 1) {
+ pieces.push_back("{");
+ for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) {
+ pieces.push_back(element_to_string({i0}));
+ }
+ pieces.push_back("}");
+ } else if (ShapeUtil::Rank(shape) == 2) {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" {\n");
+ for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) {
+ pieces.push_back(" { ");
+ for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) {
+ pieces.push_back(element_to_string({i0, i1}));
+ }
+ pieces.push_back(" ");
+ pieces.push_back("},\n");
+ }
+ pieces.push_back("}");
+ } else if (ShapeUtil::Rank(shape) == 3) {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" {\n");
+ for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) {
+ pieces.push_back(i0 > 0 ? ",\n{" : "{");
+ for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) {
+ pieces.push_back(i1 > 0 ? ",\n { " : " { ");
+ for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) {
+ pieces.push_back(element_to_string({i0, i1, i2}));
+ }
+ pieces.push_back(" }");
+ }
+ pieces.push_back(" }");
+ }
+ pieces.push_back("\n}");
+ } else if (ShapeUtil::Rank(shape) == 4) {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" {\n");
+ for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) {
+ pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0));
+ for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" { // i1=%lld\n", i1));
+ for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) {
+ pieces.push_back(" {");
+ for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) {
+ pieces.push_back(element_to_string({i0, i1, i2, i3}));
+ }
+ pieces.push_back("},\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back("}");
+ } else if (ShapeUtil::Rank(shape) == 5) {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" {\n");
+ for (int64 i0 = 0; i0 < shape.dimensions(0); ++i0) {
+ pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0));
+ for (int64 i1 = 0; i1 < shape.dimensions(1); ++i1) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" { // i1=%lld\n", i1));
+ for (int64 i2 = 0; i2 < shape.dimensions(2); ++i2) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" { // i2=%lld\n", i2));
+ for (int64 i3 = 0; i3 < shape.dimensions(3); ++i3) {
+ pieces.push_back(" {");
+ for (int64 i4 = 0; i4 < shape.dimensions(4); ++i4) {
+ pieces.push_back(element_to_string({i0, i1, i2, i3, i4}));
+ }
+ pieces.push_back("},\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back(" },\n");
+ }
+ pieces.push_back("}");
+ } else {
+ pieces.push_back(ShapeUtil::HumanString(shape));
+ pieces.push_back(" {...}");
+ }
+
+ return tensorflow::str_util::Join(pieces, "");
+}
+
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeTuple(
+ tensorflow::gtl::ArraySlice<const Literal*> elements) {
+ auto literal = MakeUnique<Literal>();
+ std::vector<Shape> shape;
+ for (const Literal* tuple_element : elements) {
+ *literal->add_tuple_literals() = *tuple_element;
+ shape.push_back(tuple_element->shape());
+ }
+ *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape);
+ return literal;
+}
+
+/* static */ const void* LiteralUtil::InternalData(const Literal& literal) {
+ switch (literal.shape().element_type()) {
+ case PRED:
+ return reinterpret_cast<const void*>(literal.preds().data());
+ case U8:
+ return reinterpret_cast<const void*>(literal.u8s().data());
+ case S32:
+ return reinterpret_cast<const void*>(literal.s32s().data());
+ case S64:
+ return reinterpret_cast<const void*>(literal.s64s().data());
+ case U32:
+ return reinterpret_cast<const void*>(literal.u32s().data());
+ case U64:
+ return reinterpret_cast<const void*>(literal.u64s().data());
+ case F32:
+ return reinterpret_cast<const void*>(literal.f32s().data());
+ case F64:
+ return reinterpret_cast<const void*>(literal.f64s().data());
+ default:
+ LOG(FATAL) << "primitive type not supported in literals: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+}
+
+/* static */ void* LiteralUtil::MutableInternalData(Literal* literal) {
+ return const_cast<void*>(LiteralUtil::InternalData(*literal));
+}
+
+/* static */ void LiteralUtil::Reserve(int64 num_elements, Literal* literal) {
+ CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements);
+ switch (literal->shape().element_type()) {
+ case PRED:
+ GetMutableRepeatedField<bool>(literal)->Resize(num_elements, false);
+ break;
+ case U8:
+ // u8s is an optional "bytes", rather than a repeated field. Therefore its
+ // access methods are somewhat different from the others.
+ literal->mutable_u8s()->resize(num_elements, 0);
+ break;
+ case S32:
+ GetMutableRepeatedField<int32>(literal)->Resize(num_elements,
+ /*value=*/0);
+ break;
+ case S64:
+ GetMutableRepeatedField<tensorflow::protobuf_int64>(literal)->Resize(
+ num_elements,
+ /*value=*/0);
+ break;
+ case U32:
+ GetMutableRepeatedField<uint32>(literal)->Resize(num_elements,
+ /*value=*/0);
+ break;
+ case U64:
+ GetMutableRepeatedField<tensorflow::protobuf_uint64>(literal)->Resize(
+ num_elements,
+ /*value=*/0);
+ break;
+ case F32:
+ GetMutableRepeatedField<float>(literal)->Resize(num_elements,
+ /*value=*/0.0f);
+ break;
+ case F64:
+ GetMutableRepeatedField<double>(literal)->Resize(num_elements,
+ /*value=*/0.0);
+ break;
+ default:
+ LOG(FATAL) << "primitive type not supported in literals: "
+ << PrimitiveType_Name(literal->shape().element_type());
+ }
+}
+
+/* static */ tensorflow::Status LiteralUtil::ValidateLiteral(
+ const Literal& literal) {
+ TF_CHECK_OK(ShapeUtil::ValidateShape(literal.shape()));
+ int64 expected = ShapeUtil::ElementsIn(literal.shape());
+ int64 actual = -1;
+ switch (literal.shape().element_type()) {
+ case PRED:
+ actual = literal.preds().size();
+ break;
+ case U8:
+ actual = literal.u8s().size();
+ break;
+ case S32:
+ actual = literal.s32s_size();
+ break;
+ case U32:
+ actual = literal.u32s_size();
+ break;
+ case S64:
+ actual = literal.s64s_size();
+ break;
+ case U64:
+ actual = literal.u64s_size();
+ break;
+ case F32:
+ actual = literal.f32s_size();
+ break;
+ case F64:
+ actual = literal.f64s_size();
+ break;
+ default:
+ return tensorflow::errors::Unimplemented(
+ "unhandled element type for literal validation: " +
+ PrimitiveType_Name(literal.shape().element_type()));
+ }
+
+ if (expected != actual) {
+ return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf(
+ "literal has bad number of elements for its shape %s: want %lld "
+ "got %lld",
+ ShapeUtil::HumanString(literal.shape()).c_str(), expected, actual));
+ }
+
+ return tensorflow::Status::OK();
+}
+
+/* static */ void LiteralUtil::EachCellAsString(
+ const Literal& literal,
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>
+ per_cell) {
+ if (ShapeUtil::Rank(literal.shape()) == 1) {
+ for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
+ per_cell({i0}, GetAsString(literal, {i0}));
+ }
+ return;
+ }
+
+ if (ShapeUtil::Rank(literal.shape()) == 2) {
+ for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
+ for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
+ per_cell({i0, i1}, GetAsString(literal, {i0, i1}));
+ }
+ }
+ return;
+ }
+
+ if (ShapeUtil::Rank(literal.shape()) == 3) {
+ for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
+ for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
+ for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) {
+ per_cell({i0, i1, i2}, GetAsString(literal, {i0, i1, i2}));
+ }
+ }
+ }
+ return;
+ }
+
+ if (ShapeUtil::Rank(literal.shape()) == 4) {
+ for (int64 i0 = 0; i0 < literal.shape().dimensions(0); ++i0) {
+ for (int64 i1 = 0; i1 < literal.shape().dimensions(1); ++i1) {
+ for (int64 i2 = 0; i2 < literal.shape().dimensions(2); ++i2) {
+ for (int64 i3 = 0; i3 < literal.shape().dimensions(3); ++i3) {
+ per_cell({i0, i1, i2, i3}, GetAsString(literal, {i0, i1, i2, i3}));
+ }
+ }
+ }
+ }
+ return;
+ }
+
+ LOG(FATAL) << "unhandled rank: " << ShapeUtil::Rank(literal.shape());
+}
+
+namespace {
+
+// Helper function which compares whether the elements of literal1 are equal to
+// the elements of literal2. Recursively iterates through the entire
+// multidimensional index space and compares the literal elements
+// one-by-one. literal1 and literal2 must be compatible (same dimensions and
+// type).
+template <typename NativeT>
+bool EqualElements(const Literal& literal1, const Literal& literal2,
+ int dimension, std::vector<int64>* multi_index) {
+ if (dimension == ShapeUtil::Rank(literal1.shape())) {
+ return (LiteralUtil::Get<NativeT>(literal1, *multi_index) ==
+ LiteralUtil::Get<NativeT>(literal2, *multi_index));
+ }
+ for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) {
+ (*multi_index)[dimension] = i;
+ if (!EqualElements<NativeT>(literal1, literal2, dimension + 1,
+ multi_index)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+/* static */ bool LiteralUtil::Equal(const Literal& literal1,
+ const Literal& literal2) {
+ if (!ShapeUtil::Compatible(literal1.shape(), literal2.shape())) {
+ return false;
+ }
+ if (ShapeUtil::IsTuple(literal1.shape())) {
+ // Because the shapes are compatible, they must have the same number of
+ // tuple elements.
+ CHECK_EQ(literal1.tuple_literals_size(), literal2.tuple_literals_size());
+ for (int i = 0; i < literal1.tuple_literals_size(); ++i) {
+ if (!Equal(literal1.tuple_literals(i), literal2.tuple_literals(i))) {
+ return false;
+ }
+ }
+ return true;
+ } else {
+ std::vector<int64> multi_index(ShapeUtil::Rank(literal1.shape()), 0);
+ switch (literal1.shape().element_type()) {
+ case PRED:
+ return EqualElements<bool>(literal1, literal2, 0, &multi_index);
+ case U8:
+ return EqualElements<uint8>(literal1, literal2, 0, &multi_index);
+ case S32:
+ return EqualElements<int32>(literal1, literal2, 0, &multi_index);
+ case S64:
+ return EqualElements<int64>(literal1, literal2, 0, &multi_index);
+ case U32:
+ return EqualElements<uint32>(literal1, literal2, 0, &multi_index);
+ case U64:
+ return EqualElements<uint64>(literal1, literal2, 0, &multi_index);
+ case F32:
+ return EqualElements<float>(literal1, literal2, 0, &multi_index);
+ case F64:
+ return EqualElements<double>(literal1, literal2, 0, &multi_index);
+ default:
+ LOG(FATAL) << "Unimplemented: LiteralUtil::Equal for type "
+ << PrimitiveType_Name(literal1.shape().element_type());
+ }
+ }
+}
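A quick sketch of Equal (mine): after the compatibility check, elements are compared one by one over the logical index space, so literals built independently but holding the same values compare equal.

auto x = xla::LiteralUtil::CreateR1<int32>({1, 2, 3});
auto y = xla::LiteralUtil::CreateR1<int32>({1, 2, 3});
auto z = xla::LiteralUtil::CreateR1<int32>({1, 2, 4});
CHECK(xla::LiteralUtil::Equal(*x, *y));   // same shape, same values
CHECK(!xla::LiteralUtil::Equal(*x, *z));  // last element differs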
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<bool> LiteralUtil::GetArraySlice<bool>(
+ const Literal& literal) {
+ CHECK(literal.shape().element_type() == PRED);
+ return literal.preds();
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<bool>*
+LiteralUtil::GetMutableRepeatedField<bool>(Literal* literal) {
+ CHECK(literal->shape().element_type() == PRED);
+ return literal->mutable_preds();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<uint32>
+LiteralUtil::GetArraySlice<uint32>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == U32);
+ return literal.u32s();
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<uint32>*
+LiteralUtil::GetMutableRepeatedField<uint32>(Literal* literal) {
+ CHECK(literal->shape().element_type() == U32);
+ return literal->mutable_u32s();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<uint64>
+LiteralUtil::GetArraySlice<uint64>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == U64);
+ return AsUInt64Slice(literal.u64s());
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>*
+LiteralUtil::GetMutableRepeatedField<tensorflow::protobuf_uint64>(
+ Literal* literal) {
+ CHECK(literal->shape().element_type() == U64);
+ return literal->mutable_u64s();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<int32>
+LiteralUtil::GetArraySlice<int32>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == S32);
+ return literal.s32s();
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<int32>*
+LiteralUtil::GetMutableRepeatedField<int32>(Literal* literal) {
+ CHECK(literal->shape().element_type() == S32);
+ return literal->mutable_s32s();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<int64>
+LiteralUtil::GetArraySlice<int64>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == S64);
+ return AsInt64Slice(literal.s64s());
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+LiteralUtil::GetMutableRepeatedField<tensorflow::protobuf_int64>(
+ Literal* literal) {
+ CHECK(literal->shape().element_type() == S64);
+ return literal->mutable_s64s();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<float>
+LiteralUtil::GetArraySlice<float>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == F32);
+ return literal.f32s();
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<float>*
+LiteralUtil::GetMutableRepeatedField<float>(Literal* literal) {
+ CHECK(literal->shape().element_type() == F32);
+ return literal->mutable_f32s();
+}
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<double>
+LiteralUtil::GetArraySlice<double>(const Literal& literal) {
+ CHECK(literal.shape().element_type() == F64);
+ return literal.f64s();
+}
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<double>*
+LiteralUtil::GetMutableRepeatedField<double>(Literal* literal) {
+ CHECK(literal->shape().element_type() == F64);
+ return literal->mutable_f64s();
+}
+
+template <typename NativeT>
+static bool AllElementsEqualValue(const Literal& literal, NativeT value) {
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ auto multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ if (LiteralUtil::Get<NativeT>(literal, multi_index) != value) {
+ return false;
+ }
+ }
+ return true;
+}
+
+/* static */ bool LiteralUtil::IsAll(const Literal& literal, int8 value) {
+ switch (literal.shape().element_type()) {
+ case U8:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint8>(literal, value);
+ }
+ return false;
+ case U32:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint32>(literal, value);
+ }
+ return false;
+ case U64:
+ if (value >= 0) {
+ return AllElementsEqualValue<uint64>(literal, value);
+ }
+ return false;
+ case S8:
+ return AllElementsEqualValue<int8>(literal, value);
+ case S32:
+ return AllElementsEqualValue<int32>(literal, value);
+ case S64:
+ return AllElementsEqualValue<int64>(literal, value);
+ case F32:
+ return AllElementsEqualValue<float>(literal, value);
+ case F64:
+ return AllElementsEqualValue<double>(literal, value);
+ case PRED:
+ if (value == 0) {
+ return AllElementsEqualValue<bool>(literal, false);
+ }
+ if (value == 1) {
+ return AllElementsEqualValue<bool>(literal, true);
+ }
+ return false;
+ default:
+ return false;
+ }
+}
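A sketch of the signed/unsigned guard in IsAll (mine): the candidate value arrives as an int8, so a negative query can never match an unsigned literal and those cases return false without scanning.

auto ones = xla::LiteralUtil::CreateR2<uint32>({{1, 1}, {1, 1}});
CHECK(xla::LiteralUtil::IsAll(*ones, 1));    // every element equals 1
CHECK(!xla::LiteralUtil::IsAll(*ones, -1));  // negative never matches unsigned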
+
+/* static */ bool LiteralUtil::IsZero(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> indices) {
+ switch (literal.shape().element_type()) {
+ case U8:
+ return Get<uint8>(literal, indices) == 0;
+ case U32:
+ return Get<uint32>(literal, indices) == 0;
+ case U64:
+ return Get<uint64>(literal, indices) == 0;
+ case S8:
+ return Get<int8>(literal, indices) == 0;
+ case S32:
+ return Get<int32>(literal, indices) == 0;
+ case S64:
+ return Get<int64>(literal, indices) == 0;
+ case F32:
+ return Get<float>(literal, indices) == 0.0f;
+ case F64:
+ return Get<double>(literal, indices) == 0.0;
+ case PRED:
+ return Get<bool>(literal, indices) == false;
+ default:
+ LOG(FATAL) << "Input literal must be an array.";
+ }
+}
+
+template <>
+/* static */ void LiteralUtil::PopulateWithValue(
+ int64 value, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<int64>(), dimensions);
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+ repeated_field =
+ GetMutableRepeatedField<tensorflow::protobuf_int64>(literal);
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) {
+ repeated_field->Add(value);
+ }
+}
+
+template <>
+/* static */ void LiteralUtil::PopulateWithValue(
+ uint64 value, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<uint64>(), dimensions);
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>*
+ repeated_field =
+ GetMutableRepeatedField<tensorflow::protobuf_uint64>(literal);
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) {
+ repeated_field->Add(value);
+ }
+}
+
+template <>
+/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value,
+ Literal* literal) {
+ CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements);
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+ repeated_field =
+ GetMutableRepeatedField<tensorflow::protobuf_int64>(literal);
+ repeated_field->Resize(num_elements, value);
+}
+
+template <>
+/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value,
+ Literal* literal) {
+ CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements);
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>*
+ repeated_field =
+ GetMutableRepeatedField<tensorflow::protobuf_uint64>(literal);
+ repeated_field->Resize(num_elements, value);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
new file mode 100644
index 0000000000..78e9e3fb24
--- /dev/null
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -0,0 +1,1004 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities for dealing with Literal protobufs.
+
+#ifndef TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
+
+#include <functional>
+#include <initializer_list>
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Utility class for dealing with XLA literal values. Most methods are
+// templated by native (host) type which corresponds to a unique XLA
+// PrimitiveType. See ComputationBuilder for details. Not all primitive types
+// defined in xla_data.proto have a corresponding native type or even have a
+// storage location in the Literal proto yet (for example, primitive type F16).
+class LiteralUtil {
+ public:
+ // Create new literal of a given rank. To minimize ambiguity (for users and
+ // the compiler) these CreateR[0-2] methods should explicitly specify the
+ // native type. For example:
+ //
+ // CreateR1<float>({1.0, 42.0});
+ // CreateR2<uint32>({{1, 2}, {3, 4}});
+ //
+ // The variants not ending with WithLayout use the default XLA layout for the
+ // literal's linear representation in memory.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR0(NativeT value);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR1(
+ tensorflow::gtl::ArraySlice<NativeT> values);
+ static std::unique_ptr<Literal> CreateR1(
+ const tensorflow::core::Bitmap& values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR2WithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ const Layout& layout);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR3(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR3WithLayout(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ values,
+ const Layout& layout);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4WithLayout(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ values,
+ const Layout& layout);
+
+ // Creates a new value that has the equivalent value as literal, but conforms
+ // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major
+ // dimension layout can be re-layed-out as {1, 0} minor-to-major dimension
+ // layout and the value in the cell at any given logical index (i0, i1) will
+ // be the same.
+ //
+  // Note: this is useful when the client wants to ensure that a value placed
+  // in the XLA allocation tracker has a particular layout, e.g. for efficiency
+  // or to avoid unimplemented operation/layout combinations.
+ static std::unique_ptr<Literal> Relayout(const Literal& literal,
+ const Layout& new_layout);
+
+ // Reshapes literal 'input' to have 'shape'. Both the original shape and
+ // 'shape' must contain the same number of elements. The implementation
+ // currently only supports monotonic dim0-major layouts.
+ static StatusOr<std::unique_ptr<Literal>> Reshape(
+ const xla::Literal& input, tensorflow::gtl::ArraySlice<int64> shape);
+
+ // Creates a new literal by reordering the dimensions of the original literal.
+ // The given `permutation` must be a permutation of the dimension numbers
+ // in the original literal, and it specifies the order of the new dimensions
+ // in the result literal (i.e., new_order[i] = old_order[permutation[i]]).
+ // For example, a transpose call on a literal of shape [3 x 8 x 4] and
+ // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8].
+ static std::unique_ptr<Literal> Transpose(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> permutation);
+
+  // Creates a sub-array from the given literal by extracting the indices
+  // [start_index, limit_index) of each dimension. The result literal has the
+  // same rank and layout as the given literal. The number of indices in
+ // start_indices and limit_indices must be the rank of the literal, and the
+ // indices follow the order of the dimensions.
+ static std::unique_ptr<Literal> Slice(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices);
+
+ // Create a literal by converting each element in an original literal to a new
+ // type.
+ template <typename NativeSrcT, typename NativeDestT>
+ static std::unique_ptr<Literal> Convert(const Literal& literal);
+
+ // Create a literal value zero of the given primitive type.
+ static Literal Zero(PrimitiveType primitive_type);
+
+ // Create a literal value one of the given primitive type.
+ static Literal One(PrimitiveType primitive_type);
+
+ // Creates a literal value containing the minimum value of the given
+ // primitive type. For floating-point types, returns -inf.
+ static Literal MinValue(PrimitiveType primitive_type);
+
+ // Create a literal value containing the maximum value of the given
+ // primitive type. For floating-point types, returns inf.
+ static Literal MaxValue(PrimitiveType primitive_type);
+
+ // Create a literal of the given shape where each element is `value`.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateFullWithMonotonicDim0MajorLayout(
+ tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value);
+
+ // Create a new literal from an array. The variants not ending with WithLayout
+ // use the default XLA layout for the literal's linear representation in
+ // memory.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR2FromArray2D(
+ const Array2D<NativeT>& values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR3FromArray3D(
+ const Array3D<NativeT>& values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4FromArray4D(
+ const Array4D<NativeT>& values);
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout);
+
+ // Creates a new vector of U8s literal value from a string.
+ static std::unique_ptr<Literal> CreateR1U8(tensorflow::StringPiece value);
+
+ // Creates a linspace-populated literal with the given number of rows and
+ // columns.
+ static std::unique_ptr<Literal> CreateR2F32Linspace(float from, float to,
+ int64 rows, int64 cols);
+
+ // Creates a literal that projects the (x, y) dimensions given in values into
+ // the z dimension given by "projection".
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR3Projected(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection);
+
+ // Creates a literal that projects the (x, y) dimensions given in values into
+ // the z and p dimensions given.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> CreateR4Projected(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection_p, int64 projection_z);
+
+ // Clones literal into an owned unique_ptr version.
+ static std::unique_ptr<Literal> CloneToUnique(const Literal& literal);
+
+ // Gets or sets an element in the literal at the given index. The index is
+ // CHECKed against the dimension sizes.
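+  //
+  // Example (illustrative sketch; `literal` is a hypothetical Literal*
+  // holding an already-populated f32[2,2] array):
+  //
+  //   LiteralUtil::Set<float>(literal, {0, 1}, 42.0f);
+  //   float v = LiteralUtil::Get<float>(*literal, {0, 1});  // v == 42.0f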
+ template <typename NativeT>
+ static NativeT Get(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+ template <typename NativeT>
+ static void Set(Literal* literal,
+ tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value);
+
+ // Returns the element value at index (0, ..., 0), however many zeroes are
+ // required for that index.
+ template <typename NativeT>
+ static NativeT GetFirstElement(const Literal& literal);
+
+ // As Get(), but determines the correct type and converts the value
+ // into text.
+ static string GetAsString(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+
+ // Returns an identity matrix (rank 2) with the given row and column count.
+ template <typename NativeT>
+ static std::unique_ptr<Literal> MakeIdentityR2(int64 size);
+
+ // Returns a tuple literal composed of given literals.
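+  //
+  // Example (illustrative sketch):
+  //
+  //   auto scalar = LiteralUtil::CreateR0<float>(1.0f);
+  //   auto matrix = LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}});
+  //   auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});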
+ static std::unique_ptr<Literal> MakeTuple(
+ tensorflow::gtl::ArraySlice<const Literal*> elements);
+
+ // Validates that the data payload of the literal matches the literal shape;
+ // if it does not, an appropriate status is returned.
+ static tensorflow::Status ValidateLiteral(const Literal& literal);
+
+ // Returns a string representation of the literal value.
+ static string ToString(const Literal& literal);
+
+ // Invokes the "per cell" callback for each element in the provided
+ // literal with the element's indices and a string representation of
+ // the element's value.
+ //
+ // This function is useful if you want a polymorphic representation
+  // of the tensor's elements (turning them into strings for something
+ // like representation in a protobuf).
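+  //
+  // Example (illustrative sketch; `literal` is a hypothetical literal):
+  //
+  //   LiteralUtil::EachCellAsString(
+  //       *literal, [](tensorflow::gtl::ArraySlice<int64> indices,
+  //                    const string& value) {
+  //         /* e.g., append `value` to a proto or log message */
+  //       });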
+ static void EachCellAsString(
+ const Literal& literal,
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value)>
+ per_cell);
+ template <typename NativeT>
+ static void EachCell(
+ const Literal& literal,
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell);
+
+ // Templated methods which populate the given repeated field in the Literal
+ // proto with the given value(s). The Shape field of the Literal proto is set
+ // to match the array dimensions and type. Examples:
+ //
+ // // Populate with floats.
+ // Array2D<float> float_values = ...
+  //   PopulateR2FromArray2D(float_values, literal);
+ //
+ // // Populate with int32s.
+ // PopulateR2({{1, 2}, {3, 4}}, literal);
+ //
+ template <typename NativeT>
+ static void PopulateR0(NativeT values, Literal* literal);
+ template <typename NativeT>
+ static void PopulateR1(tensorflow::gtl::ArraySlice<NativeT> values,
+ Literal* literal);
+ static void PopulateR1(const tensorflow::core::Bitmap& values,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR2WithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ const Layout& layout, Literal* literal);
+ template <typename NativeT>
+ static void PopulateR2FromArray2D(const Array2D<NativeT>& values,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR3FromArray3D(const Array3D<NativeT>& values,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR4FromArray4D(const Array4D<NativeT>& values,
+ Literal* literal);
+ template <typename NativeT>
+ static void PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout,
+ Literal* literal);
+
+  // Populates `literal` to be an array of the given dimensions with every
+  // element set to `value`.
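+  //
+  // Example (illustrative sketch):
+  //
+  //   Literal literal;
+  //   LiteralUtil::PopulateWithValue<float>(2.5f, {2, 2}, &literal);
+  //   // literal is now a 2x2 f32 array with every element equal to 2.5.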
+ template <typename NativeT>
+ static void PopulateWithValue(NativeT value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal);
+
+ // Returns a pointer to the underlying buffer in the protobuf containing the
+ // array data. Use with care.
+ static const void* InternalData(const Literal& literal);
+ static void* MutableInternalData(Literal* literal);
+
+ // Allocates space in the repeated_field of the literal sufficient to hold
+ // num_elements of the literal's primitive type. Values in the buffer are set
+  // to zero. num_elements must equal the number of elements in the literal's
+ // shape.
+ static void Reserve(int64 num_elements, Literal* literal);
+
+ // Allocates space in the repeated_field of the literal sufficient to hold
+ // num_elements of the literal's primitive type and sets each element in the
+ // literal to the given value. num_elements must equal the number of elements
+  // in the literal's shape.
+ template <typename NativeT>
+ static void Resize(int64 num_elements, NativeT value, Literal* literal);
+
+ // Returns true if the two given literals have the same shape and
+ // values. Layout is not considered in the comparison.
+ static bool Equal(const Literal& literal1, const Literal& literal2);
+
+ // Returns whether every element in the given literal is equal to value.
+ //
+ // value is an int8 because we expect this to be called with small
+ // compile-time constants (0, -1, etc.) and so that whatever value you pass
+ // can be represented exactly by floating-point types as small as 16 bits.
+ //
+ // If value doesn't fit in literal's type, returns false. Values of 1/0 are
+ // considered equal to true/false; other values are not considered equal to
+ // true.
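+  //
+  // Example (illustrative sketch):
+  //
+  //   auto ones = LiteralUtil::CreateR2<uint64>({{1, 1}, {1, 1}});
+  //   LiteralUtil::IsAll(*ones, 1);  // true
+  //   LiteralUtil::IsAll(*ones, 0);  // false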
+ static bool IsAll(const Literal& literal, int8 value);
+
+ // Returns whether the literal is zero at the specified index. The literal
+ // must be an array.
+ static bool IsZero(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> indices);
+
+ private:
+  // Returns an ArraySlice view of the given literal's array for the given
+  // NativeT (e.g., float). These functions map a native type to an XLA
+  // PrimitiveType via template specialization. The unspecialized forms below
+  // fail to compile (via static_assert) if the given native type does not map
+  // to an XLA primitive type.
+ template <typename NativeT>
+ static tensorflow::gtl::ArraySlice<NativeT> GetArraySlice(
+ const Literal& literal) {
+ static_assert(!std::is_same<NativeT, NativeT>::value,
+ "Cannot map native type to primitive type.");
+ }
+ template <typename NativeT>
+ static tensorflow::protobuf::RepeatedField<NativeT>* GetMutableRepeatedField(
+ Literal* literal) {
+ // Make the expression depend on the template parameter NativeT so
+    // that this compile-time error only appears if this function is
+ // instantiated with some concrete type that is not specialized
+ // below.
+ static_assert(!std::is_same<NativeT, NativeT>::value,
+ "Cannot map native type to primitive type.");
+ }
+
+ // Returns the linear index of the given index within the literal's
+ // element_type repeated field.
+ static int64 LinearIndex(const Literal& literal,
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil);
+};
+
+// Declarations of template specializations for GetArraySlice and
+// GetMutableRepeatedField. The specializations map native type to XLA primitive
+// type.
+template <>
+/* static */ tensorflow::gtl::ArraySlice<bool> LiteralUtil::GetArraySlice<bool>(
+ const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<bool>*
+LiteralUtil::GetMutableRepeatedField<bool>(Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<uint32>
+LiteralUtil::GetArraySlice<uint32>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<uint32>*
+LiteralUtil::GetMutableRepeatedField<uint32>(Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<uint64>
+LiteralUtil::GetArraySlice<uint64>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>*
+LiteralUtil::GetMutableRepeatedField<tensorflow::protobuf_uint64>(
+ Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<int32>
+LiteralUtil::GetArraySlice<int32>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<int32>*
+LiteralUtil::GetMutableRepeatedField<int32>(Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<int64>
+LiteralUtil::GetArraySlice<int64>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>*
+LiteralUtil::GetMutableRepeatedField<tensorflow::protobuf_int64>(
+ Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<float>
+LiteralUtil::GetArraySlice<float>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<float>*
+LiteralUtil::GetMutableRepeatedField<float>(Literal* literal);
+
+template <>
+/* static */ tensorflow::gtl::ArraySlice<double>
+LiteralUtil::GetArraySlice<double>(const Literal& literal);
+
+template <>
+/* static */ tensorflow::protobuf::RepeatedField<double>*
+LiteralUtil::GetMutableRepeatedField<double>(Literal* literal);
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR0(NativeT value) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR0(value, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR1(
+ tensorflow::gtl::ArraySlice<NativeT> values) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR1(values, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2WithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ const Layout& layout) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR2WithLayout(values, layout, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values) {
+ return CreateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3WithLayout(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ values,
+ const Layout& layout) {
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ const int64 d2 = values.begin()->begin()->size();
+ Array3D<NativeT> tmp(d0, d1, d2);
+ int64 i0 = 0;
+ for (auto d1_values : values) {
+ int64 i1 = 0;
+ for (auto d2_values : d1_values) {
+ int64 i2 = 0;
+ for (auto value : d2_values) {
+ tmp(i0, i1, i2) = value;
+ ++i2;
+ }
+ ++i1;
+ }
+ ++i0;
+ }
+ return CreateR3FromArray3DWithLayout(tmp, layout);
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ values) {
+ return CreateR3WithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4WithLayout(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ values,
+ const Layout& layout) {
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ const int64 d2 = values.begin()->begin()->size();
+ const int64 d3 = values.begin()->begin()->begin()->size();
+ Array4D<NativeT> tmp(d0, d1, d2, d3);
+ int64 i0 = 0;
+ for (auto d1_values : values) {
+ int64 i1 = 0;
+ for (auto d2_values : d1_values) {
+ int64 i2 = 0;
+ for (auto d3_values : d2_values) {
+ int64 i3 = 0;
+ for (auto value : d3_values) {
+ tmp(i0, i1, i2, i3) = value;
+ ++i3;
+ }
+ ++i2;
+ }
+ ++i1;
+ }
+ ++i0;
+ }
+ return CreateR4FromArray4DWithLayout(tmp, layout);
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4(
+ std::initializer_list<std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>>
+ values) {
+ return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
+ const Layout& layout) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR2FromArray2DWithLayout(values, layout, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR2FromArray2D(
+ const Array2D<NativeT>& values) {
+ return CreateR2FromArray2DWithLayout(values,
+ LayoutUtil::GetDefaultLayoutForR2());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
+ const Layout& layout) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR3FromArray3DWithLayout(values, layout, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3FromArray3D(
+ const Array3D<NativeT>& values) {
+ return CreateR3FromArray3DWithLayout(values,
+ LayoutUtil::GetDefaultLayoutForR3());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR3Projected(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection) {
+ int64 dim0_size = projection;
+ int64 dim1_size = values.size();
+ int64 dim2_size = values.begin()->size();
+
+ Array3D<NativeT> array(dim0_size, dim1_size, dim2_size);
+ for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
+ int64 dim1 = 0;
+ for (auto inner_list : values) {
+ int64 dim2 = 0;
+ for (auto value : inner_list) {
+ array(dim0, dim1, dim2) = value;
+ ++dim2;
+ }
+ CHECK_EQ(dim2_size, dim2);
+ ++dim1;
+ }
+ CHECK_EQ(dim1_size, dim1);
+ }
+ return CreateR3FromArray3D(array);
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4Projected(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ int64 projection_p, int64 projection_z) {
+ int64 dim0_size = projection_p;
+ int64 dim1_size = projection_z;
+ int64 dim2_size = values.size();
+ int64 dim3_size = values.begin()->size();
+
+ Array4D<NativeT> array(dim0_size, dim1_size, dim2_size, dim3_size);
+ for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
+ for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
+ int64 dim2 = 0;
+ for (auto inner_list : values) {
+ int64 dim3 = 0;
+ for (auto value : inner_list) {
+ array(dim0, dim1, dim2, dim3) = value;
+ ++dim3;
+ }
+ CHECK_EQ(dim3_size, dim3);
+ ++dim2;
+ }
+ CHECK_EQ(dim2_size, dim2);
+ }
+ }
+ return CreateR4FromArray4D(array);
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::CreateR4FromArray4D(
+ const Array4D<NativeT>& values) {
+ return CreateR4FromArray4DWithLayout(values,
+ LayoutUtil::GetDefaultLayoutForR4());
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
+ const Layout& layout) {
+ auto literal = MakeUnique<Literal>();
+ PopulateR4FromArray4DWithLayout(values, layout, literal.get());
+ return literal;
+}
+
+template <typename NativeT>
+/* static */ NativeT LiteralUtil::Get(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ int64 linear_index = LinearIndex(literal, multi_index);
+ return GetArraySlice<NativeT>(literal).at(linear_index);
+}
+
+template <typename NativeT>
+/* static */ NativeT LiteralUtil::GetFirstElement(const Literal& literal) {
+ return GetArraySlice<NativeT>(literal).at(0);
+}
+
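+// The U8 and S8 specializations of Get (and of Set below) access elements
+// directly through the Literal proto's u8s bytes field, which is where 8-bit
+// values are stored (see also the PopulateR0<uint8> and PopulateR0<int8>
+// specializations further down).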
+template <>
+/* static */ inline uint8 LiteralUtil::Get<uint8>(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ CHECK(literal.shape().element_type() == U8);
+ int64 linear_index = LinearIndex(literal, multi_index);
+ return literal.u8s()[linear_index];
+}
+
+template <>
+/* static */ inline int8 LiteralUtil::Get<int8>(
+ const Literal& literal, tensorflow::gtl::ArraySlice<int64> multi_index) {
+ CHECK(literal.shape().element_type() == S8);
+ int64 linear_index = LinearIndex(literal, multi_index);
+ return literal.u8s()[linear_index];
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::Set(
+ Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+ NativeT value) {
+ int64 linear_index = LinearIndex(*literal, multi_index);
+ GetMutableRepeatedField<NativeT>(literal)->Set(linear_index, value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set(
+ Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+ uint8 value) {
+ int64 linear_index = LinearIndex(*literal, multi_index);
+ (*literal->mutable_u8s())[linear_index] = value;
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set(
+ Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+ int8 value) {
+ return Set<uint8>(literal, multi_index, value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set(
+ Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+ int64 value) {
+ int64 linear_index = LinearIndex(*literal, multi_index);
+ (*literal->mutable_s64s())[linear_index] = value;
+}
+
+template <>
+/* static */ inline void LiteralUtil::Set(
+ Literal* literal, tensorflow::gtl::ArraySlice<int64> multi_index,
+ uint64 value) {
+ int64 linear_index = LinearIndex(*literal, multi_index);
+ (*literal->mutable_u64s())[linear_index] = value;
+}
+
+// Returns an identity matrix (rank 2) with the given row and column count.
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::MakeIdentityR2(int64 size) {
+ Array2D<NativeT> array(size, size, 0);
+ for (int64 i = 0; i < size; ++i) {
+ array(i, i) = 1;
+ }
+ return CreateR2FromArray2D(array);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::EachCell(
+ const Literal& literal,
+ std::function<void(tensorflow::gtl::ArraySlice<int64> indices,
+ NativeT value)>
+ per_cell) {
+ if (ShapeUtil::HasZeroElements(literal.shape())) {
+ return;
+ }
+ std::vector<int64> indices(ShapeUtil::Rank(literal.shape()), 0);
+ do {
+ per_cell(indices, Get<NativeT>(literal, indices));
+ } while (IndexUtil::BumpIndices(literal.shape(), &indices));
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR0(NativeT value, Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), {});
+ tensorflow::protobuf::RepeatedField<NativeT>* repeated_field =
+ GetMutableRepeatedField<NativeT>(literal);
+ repeated_field->Add(value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::PopulateR0<uint8>(uint8 value,
+ Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<uint8>(), {});
+ literal->mutable_u8s()->push_back(value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::PopulateR0<int8>(int8 value,
+ Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<int8>(), {});
+ literal->mutable_u8s()->push_back(value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::PopulateR0<uint64>(uint64 value,
+ Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<uint64>(), {});
+ literal->mutable_u64s()->Add(value);
+}
+
+template <>
+/* static */ inline void LiteralUtil::PopulateR0<int64>(int64 value,
+ Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<int64>(), {});
+ literal->mutable_s64s()->Add(value);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR1(
+ tensorflow::gtl::ArraySlice<NativeT> values, Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
+ {static_cast<int64>(values.size())});
+ Reserve(values.size(), literal);
+ for (int64 i = 0; i < values.size(); ++i) {
+ Set(literal, {i}, values[i]);
+ }
+}
+
+/* static */ inline void LiteralUtil::PopulateR1(
+ const tensorflow::core::Bitmap& values, Literal* literal) {
+ *literal->mutable_shape() =
+ ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())});
+ Reserve(values.bits(), literal);
+ for (int64 i = 0; i < values.bits(); ++i) {
+ Set(literal, {i}, values.get(i));
+ }
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR2WithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ const Layout& layout, Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(),
+ {static_cast<int64>(values.size()),
+ static_cast<int64>(values.begin()->size())},
+ AsInt64Slice(layout.minor_to_major()));
+
+ const int64 dim0_size = values.size();
+ const int64 dim1_size = values.begin()->size();
+ CHECK_EQ(dim0_size, literal->shape().dimensions(0));
+ CHECK_EQ(dim1_size, literal->shape().dimensions(1));
+
+ const int64 num_elements = dim1_size * dim0_size;
+ Reserve(num_elements, literal);
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto value : inner_list) {
+ Set(literal, {dim0, dim1}, value);
+ ++dim1;
+ }
+ CHECK_EQ(dim1_size, dim1);
+ ++dim0;
+ }
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR2(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ Literal* literal) {
+ PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2(), literal);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR2FromArray2DWithLayout(
+ const Array2D<NativeT>& values, const Layout& layout, Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(),
+ {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
+
+ const int64 dim1_size = values.width();
+ const int64 dim0_size = values.height();
+ CHECK_EQ(dim0_size, literal->shape().dimensions(0));
+ CHECK_EQ(dim1_size, literal->shape().dimensions(1));
+ Reserve(dim1_size * dim0_size, literal);
+ for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
+ for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
+ Set(literal, {dim0, dim1}, values(dim0, dim1));
+ }
+ }
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR2FromArray2D(
+ const Array2D<NativeT>& values, Literal* literal) {
+ PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2(),
+ literal);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR3FromArray3DWithLayout(
+ const Array3D<NativeT>& values, const Layout& layout, Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(),
+ {values.n1(), values.n2(), values.n3()},
+ AsInt64Slice(layout.minor_to_major()));
+
+ CHECK_EQ(values.n1(), literal->shape().dimensions(0));
+ CHECK_EQ(values.n2(), literal->shape().dimensions(1));
+ CHECK_EQ(values.n3(), literal->shape().dimensions(2));
+ Reserve(values.n1() * values.n2() * values.n3(), literal);
+ for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
+ for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
+ for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
+ Set(literal, {dim0, dim1, dim2}, values(dim0, dim1, dim2));
+ }
+ }
+ }
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR3FromArray3D(
+ const Array3D<NativeT>& values, Literal* literal) {
+ PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3(),
+ literal);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR4FromArray4DWithLayout(
+ const Array4D<NativeT>& values, const Layout& layout, Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(),
+ {values.planes(), values.depth(), values.height(), values.width()},
+ AsInt64Slice(layout.minor_to_major()));
+
+ CHECK_EQ(values.n1(), literal->shape().dimensions(0));
+ CHECK_EQ(values.n2(), literal->shape().dimensions(1));
+ CHECK_EQ(values.n3(), literal->shape().dimensions(2));
+ CHECK_EQ(values.n4(), literal->shape().dimensions(3));
+ Reserve(values.n1() * values.n2() * values.n3() * values.n4(), literal);
+ for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
+ for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
+ for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
+ for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
+ Set(literal, {dim0, dim1, dim2, dim3},
+ values(dim0, dim1, dim2, dim3));
+ }
+ }
+ }
+ }
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateR4FromArray4D(
+ const Array4D<NativeT>& values, Literal* literal) {
+ PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4(),
+ literal);
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::PopulateWithValue(
+ NativeT value, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal) {
+ *literal->mutable_shape() = ShapeUtil::MakeShape(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
+ tensorflow::protobuf::RepeatedField<NativeT>* repeated_field =
+ GetMutableRepeatedField<NativeT>(literal);
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal->shape()); ++i) {
+ repeated_field->Add(value);
+ }
+}
+
+template <>
+/* static */ void LiteralUtil::PopulateWithValue(
+ int64 value, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal);
+
+template <>
+/* static */ void LiteralUtil::PopulateWithValue(
+ uint64 value, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Literal* literal);
+
+template <typename NativeSrcT, typename NativeDestT>
+/* static */ std::unique_ptr<Literal> LiteralUtil::Convert(
+ const Literal& literal) {
+ auto result_literal = MakeUnique<Literal>();
+ Shape result_shape = literal.shape();
+ result_shape.set_element_type(
+ primitive_util::NativeToPrimitiveType<NativeDestT>());
+ *result_literal->mutable_shape() = result_shape;
+ LiteralUtil::Reserve(ShapeUtil::ElementsIn(result_shape),
+ result_literal.get());
+ LiteralUtil::EachCell<NativeSrcT>(
+ literal,
+ [&](tensorflow::gtl::ArraySlice<int64> indices, NativeSrcT value) {
+ LiteralUtil::Set<NativeDestT>(result_literal.get(), indices,
+ static_cast<NativeDestT>(value));
+ });
+ return result_literal;
+}
+
+template <typename NativeT>
+/* static */ void LiteralUtil::Resize(int64 num_elements, NativeT value,
+ Literal* literal) {
+ CHECK_EQ(ShapeUtil::ElementsIn(literal->shape()), num_elements);
+ tensorflow::protobuf::RepeatedField<NativeT>* repeated_field =
+ GetMutableRepeatedField<NativeT>(literal);
+ repeated_field->Resize(num_elements, value);
+}
+
+template <>
+/* static */ void LiteralUtil::Resize(int64 num_elements, int64 value,
+ Literal* literal);
+
+template <>
+/* static */ void LiteralUtil::Resize(int64 num_elements, uint64 value,
+ Literal* literal);
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal>
+LiteralUtil::CreateFullWithMonotonicDim0MajorLayout(
+ tensorflow::gtl::ArraySlice<int64> dimensions, NativeT value) {
+ Shape shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ primitive_util::NativeToPrimitiveType<NativeT>(), dimensions);
+ auto literal = MakeUnique<Literal>();
+ *literal->mutable_shape() = shape;
+ Reserve(ShapeUtil::ElementsIn(shape), literal.get());
+ std::vector<int64> index(dimensions.size(), 0);
+ do {
+ Set(literal.get(), index, value);
+ } while (IndexUtil::BumpIndices(shape, &index));
+ return literal;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_LITERAL_UTIL_H_
diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc
new file mode 100644
index 0000000000..09410d5c33
--- /dev/null
+++ b/tensorflow/compiler/xla/literal_util_test.cc
@@ -0,0 +1,622 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/literal_util.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class LiteralUtilTest : public ::testing::Test {
+ protected:
+ LiteralUtilTest() {
+ Array4D<float> arr4d({
+ // clang-format off
+ { // i0=0
+ { // i1=0
+ {1, 2, 3}, // i2=0
+ {4, 5, 6}, // i2=1
+ {7, 8, 9}, // i2=2
+ },
+ { // i1=1
+ {11, 12, 13},
+ {14, 15, 16},
+ {17, 18, 19},
+ },
+ },
+ { // i0=1
+ { // i1=0
+ {101, 102, 103},
+ {104, 105, 106},
+ {107, 108, 109},
+ },
+ { // i1=1
+ {201, 202, 203}, // i2=0
+ {204, 205, 206}, // i2=1
+ {207, 208, 209}, // i2=2
+ },
+ },
+ // clang-format on
+ });
+
+ layout_r2_dim0major_ = LayoutUtil::MakeLayout({1, 0});
+ layout_r2_dim0minor_ = LayoutUtil::MakeLayout({0, 1});
+ layout_r3_dim0major_ = LayoutUtil::MakeLayout({2, 1, 0});
+ layout_r3_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2});
+ layout_r4_dim0major_ = LayoutUtil::MakeLayout({3, 2, 1, 0});
+ layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ literal_r4_2x2x3x3_dim0major_ =
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0major_);
+ literal_r4_2x2x3x3_dim0minor_ =
+ LiteralUtil::CreateR4FromArray4DWithLayout<float>(arr4d,
+ layout_r4_dim0minor_);
+ }
+
+ Layout layout_r2_dim0major_;
+ Layout layout_r2_dim0minor_;
+ Layout layout_r3_dim0major_;
+ Layout layout_r3_dim0minor_;
+ Layout layout_r4_dim0major_;
+ Layout layout_r4_dim0minor_;
+ std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0major_;
+ std::unique_ptr<Literal> literal_r4_2x2x3x3_dim0minor_;
+};
+
+TEST_F(LiteralUtilTest, LiteralScalarToString) {
+ auto true_lit = LiteralUtil::CreateR0<bool>(true);
+ ASSERT_EQ("true", LiteralUtil::ToString(*true_lit));
+
+ auto false_lit = LiteralUtil::CreateR0<bool>(false);
+ ASSERT_EQ("false", LiteralUtil::ToString(*false_lit));
+
+ auto u32_lit = LiteralUtil::CreateR0<uint32>(42);
+ ASSERT_EQ("42", LiteralUtil::ToString(*u32_lit));
+
+ auto s32_lit = LiteralUtil::CreateR0<int32>(-999);
+ ASSERT_EQ("-999", LiteralUtil::ToString(*s32_lit));
+
+ auto f32_lit = LiteralUtil::CreateR0<float>(3.14f);
+ ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit));
+}
+
+TEST_F(LiteralUtilTest, LiteralVectorToString) {
+ auto pred_vec = LiteralUtil::CreateR1<bool>({true, false, true});
+ ASSERT_EQ("{101}", LiteralUtil::ToString(*pred_vec));
+}
+
+TEST_F(LiteralUtilTest, R2ToString) {
+ const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}});
+ const string expected = R"(s32[3,2] {
+ { 1, 2 },
+ { 3, 4 },
+ { 5, 6 },
+})";
+ ASSERT_EQ(expected, LiteralUtil::ToString(*literal));
+}
+
+TEST_F(LiteralUtilTest, R3ToString) {
+ const auto literal =
+ LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}});
+ const string expected = R"(s32[3,2,1] {
+{ { 1 },
+ { 2 } },
+{ { 3 },
+ { 4 } },
+{ { 5 },
+ { 6 } }
+})";
+ ASSERT_EQ(expected, LiteralUtil::ToString(*literal));
+}
+
+TEST_F(LiteralUtilTest, TupleToString) {
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+ const string expected = R"((f32[], f32[2,2]) (
+1,
+f32[2,2] {
+ { 1, 2 },
+ { 3, 4 },
+},
+))";
+ ASSERT_EQ(expected, LiteralUtil::ToString(*tuple));
+}
+
+TEST_F(LiteralUtilTest, CreateR3FromArray3d) {
+ // clang-format off
+ Array3D<float> array_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ // clang-format on
+
+ auto literal = LiteralUtil::CreateR3FromArray3D(array_3d);
+ EXPECT_MATCH(testing::PBToVec<tensorflow::protobuf_int64>(
+ literal->shape().dimensions()),
+ testing::VectorMatcher<tensorflow::protobuf_int64>({2, 3, 2}));
+ string result = LiteralUtil::ToString(*literal);
+ const string expected = R"(f32[2,3,2] {
+{ { 1, 2 },
+ { 3, 4 },
+ { 5, 6 } },
+{ { 7, 8 },
+ { 9, 10 },
+ { 11, 12 } }
+})";
+ ASSERT_EQ(expected, result);
+}
+
+TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
+ // clang-format off
+ auto literal = LiteralUtil::CreateR4Projected<float>({
+ {1, 2},
+ {1001, 1002},
+ {2001, 2002},
+ }, /*projection_p=*/1, /*projection_z=*/2);
+ // clang-format on
+ EXPECT_MATCH(
+ testing::PBToVec(literal->shape().dimensions()),
+ testing::VectorMatcher<tensorflow::protobuf_int64>({1, 2, 3, 2}));
+ string result = LiteralUtil::ToString(*literal);
+ const string expected = R"(f32[1,2,3,2] {
+ { // i0=0
+ { // i1=0
+ {1, 2},
+ {1001, 1002},
+ {2001, 2002},
+ },
+ { // i1=1
+ {1, 2},
+ {1001, 1002},
+ {2001, 2002},
+ },
+ },
+})";
+ ASSERT_EQ(expected, result);
+}
+
+TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) {
+ EXPECT_MATCH(
+ testing::PBToVec<tensorflow::protobuf_int64>(
+ literal_r4_2x2x3x3_dim0major_->shape().dimensions()),
+ testing::VectorMatcher<tensorflow::protobuf_int64>({2, 2, 3, 3}));
+ string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_);
+ const string expected = R"(f32[2,2,3,3] {
+ { // i0=0
+ { // i1=0
+ {1, 2, 3},
+ {4, 5, 6},
+ {7, 8, 9},
+ },
+ { // i1=1
+ {11, 12, 13},
+ {14, 15, 16},
+ {17, 18, 19},
+ },
+ },
+ { // i0=1
+ { // i1=0
+ {101, 102, 103},
+ {104, 105, 106},
+ {107, 108, 109},
+ },
+ { // i1=1
+ {201, 202, 203},
+ {204, 205, 206},
+ {207, 208, 209},
+ },
+ },
+})";
+ ASSERT_EQ(expected, result);
+}
+
+TEST_F(LiteralUtilTest, EachCellR2F32) {
+ // clang-format off
+ auto literal = LiteralUtil::CreateR2<float>({
+ {3.1f, 4.2f},
+ {9.3f, 12.4f},
+ });
+ // clang-format on
+ std::vector<std::tuple<int64, int64, string>> seen;
+ LiteralUtil::EachCellAsString(
+ *literal,
+ [&seen](tensorflow::gtl::ArraySlice<int64> indices, const string& value) {
+ seen.emplace_back(indices[0], indices[1], value);
+ });
+
+ using Elem = std::tuple<int64, int64, string>;
+ std::vector<Elem> expected = {Elem(0, 0, "3.1"), Elem(0, 1, "4.2"),
+ Elem(1, 0, "9.3"), Elem(1, 1, "12.4")};
+ EXPECT_EQ(expected, seen);
+}
+
+TEST_F(LiteralUtilTest, ScalarEquality) {
+ // Test LiteralUtil::Equal with scalars.
+ auto f32_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto f32_42_clone = LiteralUtil::CreateR0<float>(42.0);
+
+ EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42));
+ EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42_clone));
+
+ auto f32_123 = LiteralUtil::CreateR0<float>(123.0);
+ EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f32_123));
+
+ auto f64_42 = LiteralUtil::CreateR0<double>(42.0);
+ EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f64_42));
+}
+
+TEST_F(LiteralUtilTest, NonScalarEquality) {
+ // Test LiteralUtil::Equal with nonscalars.
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_clone = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix_different =
+ LiteralUtil::CreateR2<float>({{4.0, 3.0}, {1.0, 2.0}});
+ auto vector_literal = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+
+ EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix));
+ EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix_clone));
+ EXPECT_FALSE(LiteralUtil::Equal(*matrix, *matrix_different));
+ EXPECT_FALSE(LiteralUtil::Equal(*matrix, *vector_literal));
+ EXPECT_FALSE(LiteralUtil::Equal(*matrix, *scalar));
+}
+
+TEST_F(LiteralUtilTest, DifferentLayoutEquality) {
+ // Test LiteralUtil::Equal with literals which have different layouts.
+ auto colmajor = MakeUnique<Literal>();
+ *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
+ *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ LiteralUtil::Reserve(4, colmajor.get());
+ LiteralUtil::Set<float>(colmajor.get(), {0, 0}, 1.0);
+ LiteralUtil::Set<float>(colmajor.get(), {0, 1}, 2.0);
+ LiteralUtil::Set<float>(colmajor.get(), {1, 0}, 3.0);
+ LiteralUtil::Set<float>(colmajor.get(), {1, 1}, 4.0);
+
+ auto rowmajor = MakeUnique<Literal>();
+ *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2});
+ *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ LiteralUtil::Reserve(4, rowmajor.get());
+ LiteralUtil::Set<float>(rowmajor.get(), {0, 0}, 1.0);
+ LiteralUtil::Set<float>(rowmajor.get(), {0, 1}, 2.0);
+ LiteralUtil::Set<float>(rowmajor.get(), {1, 0}, 3.0);
+ LiteralUtil::Set<float>(rowmajor.get(), {1, 1}, 4.0);
+
+ EXPECT_TRUE(LiteralUtil::Equal(*rowmajor, *colmajor));
+}
+
+TEST_F(LiteralUtilTest, TupleEquality) {
+ // Test LiteralUtil::Equal with tuples.
+ auto scalar = LiteralUtil::CreateR0<float>(1.0);
+ auto matrix = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()});
+
+ // Tuple with the same elements. One element is shared with the original
+ // tuple, the other is a clone of the element in the original tuple.
+ auto scalar_clone = LiteralUtil::CreateR0<float>(1.0);
+ auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()});
+ EXPECT_TRUE(LiteralUtil::Equal(*tuple1, *tuple2));
+
+ // Tuple with elements reversed.
+ auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()});
+ EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *reversed_tuple));
+
+ // Tuple with different value.
+ auto scalar_42 = LiteralUtil::CreateR0<float>(42.0);
+ auto different_tuple =
+ LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()});
+ EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *different_tuple));
+}
+
+TEST_F(LiteralUtilTest, IsAllTuple) {
+ auto element1 = LiteralUtil::CreateR0<float>(0.0);
+ auto element2 = LiteralUtil::CreateR2<float>({{0.0, 0.0}, {0.0, 0.0}});
+ auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()});
+
+ // Tuples should always return false for IsAll.
+ EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 0));
+ EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 1));
+}
+
+TEST_F(LiteralUtilTest, IsAll) {
+ EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(false), 0));
+ EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(true), 1));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(false), 1));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(false), 2));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(true), 0));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(true), 2));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<bool>(true), -1));
+
+ // We shouldn't reinterpret int8_min as an unsigned type and then decide that
+ // it is equal to 255.
+ auto int8_min = std::numeric_limits<int8>::min();
+ EXPECT_FALSE(
+ LiteralUtil::IsAll(*LiteralUtil::CreateR0<uint8>(255), int8_min));
+
+ EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<float>(42.0), 42));
+ EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0<float>(42.0001), 42));
+
+ EXPECT_TRUE(
+ LiteralUtil::IsAll(*LiteralUtil::CreateR1<int>({100, 100, 100}), 100));
+ EXPECT_FALSE(LiteralUtil::IsAll(
+ *LiteralUtil::CreateR1<double>({100, 100, 100.001}), 100));
+
+ EXPECT_TRUE(
+ LiteralUtil::IsAll(*LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 8}}), 8));
+ EXPECT_FALSE(
+ LiteralUtil::IsAll(*LiteralUtil::CreateR2<uint64>({{8, 8}, {8, 9}}), 8));
+ EXPECT_FALSE(
+ LiteralUtil::IsAll(*LiteralUtil::CreateR2<uint64>({{9, 8}, {8, 8}}), 8));
+
+ auto uint64_max = std::numeric_limits<uint64>::max();
+ EXPECT_FALSE(LiteralUtil::IsAll(
+ *LiteralUtil::CreateR2<uint64>(
+ {{uint64_max, uint64_max}, {uint64_max, uint64_max}}),
+ -1));
+}
+
+TEST_F(LiteralUtilTest, IsZero) {
+ auto scalar_zero = LiteralUtil::CreateR0<float>(0.0f);
+ auto scalar_one = LiteralUtil::CreateR0<float>(1.0f);
+ EXPECT_TRUE(LiteralUtil::IsZero(*scalar_zero, {}));
+ EXPECT_FALSE(LiteralUtil::IsZero(*scalar_one, {}));
+
+ auto array = LiteralUtil::CreateR2<uint32>({{1, 2, 0, 3}, {1, 0, 1, 2}});
+ EXPECT_FALSE(LiteralUtil::IsZero(*array, {0, 1}));
+ EXPECT_TRUE(LiteralUtil::IsZero(*array, {0, 2}));
+ EXPECT_TRUE(LiteralUtil::IsZero(*array, {1, 1}));
+ EXPECT_FALSE(LiteralUtil::IsZero(*array, {1, 2}));
+}
+
+template <typename T>
+class LiteralUtilTestTemplated : public ::testing::Test {};
+
+using TestedTypes = ::testing::Types<float, int32, uint32>;
+TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes);
+
+TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) {
+ // Make a non-integer for floating point types.
+ TypeParam half = TypeParam(1) / TypeParam(2);
+ auto data = LiteralUtil::CreateR2<TypeParam>({{half, 2}, {3, 4}});
+ const Layout layout01 = LayoutUtil::MakeLayout({0, 1});
+ const Layout layout10 = LayoutUtil::MakeLayout({1, 0});
+
+ auto data01 = LiteralUtil::Relayout(*data, layout01);
+ EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01));
+ EXPECT_TRUE(LiteralUtil::Equal(*data, *data01));
+
+ auto data10 = LiteralUtil::Relayout(*data, layout10);
+ EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10));
+ EXPECT_TRUE(LiteralUtil::Equal(*data, *data10));
+}
+
+TEST_F(LiteralUtilTest, ReshapeR0) {
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
+ auto reshape =
+ LiteralUtil::Reshape(*original, /*shape=*/{}).ConsumeValueOrDie();
+ EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape));
+}
+
+TEST_F(LiteralUtilTest, ReshapeR4) {
+ // clang-format off
+ // F32[1x3x2x4]
+ auto original = LiteralUtil::CreateR4WithLayout<float>({{
+ {{10, 11, 12, 13}, {14, 15, 16, 17}},
+ {{18, 19, 20, 21}, {22, 23, 24, 25}},
+ {{26, 27, 28, 29}, {30, 31, 32, 33}},
+ }}, layout_r4_dim0major_);
+  // F32[3x4x2]
+ auto expected = LiteralUtil::CreateR3WithLayout<float>({
+ {{10, 11}, {12, 13}, {14, 15}, {16, 17}},
+ {{18, 19}, {20, 21}, {22, 23}, {24, 25}},
+ {{26, 27}, {28, 29}, {30, 31}, {32, 33}},
+ }, layout_r3_dim0major_);
+ // clang-format on
+ auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie();
+
+ EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape));
+}
+
+TEST_F(LiteralUtilTest, TransposeR0) {
+ auto original = LiteralUtil::CreateR0<float>(1.7f);
+ auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{});
+ EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape));
+}
+
+TEST_F(LiteralUtilTest, TransposeR4) {
+ // clang-format off
+ // F32[1x3x2x4]
+ auto original = LiteralUtil::CreateR4<float>({{
+ {{10, 11, 12, 13}, {14, 15, 16, 17}},
+ {{18, 19, 20, 21}, {22, 23, 24, 25}},
+ {{26, 27, 28, 29}, {30, 31, 32, 33}},
+ }});
+ // clang-format on
+ auto reshape =
+ LiteralUtil::Transpose(*original, /*permutation=*/{2, 3, 0, 1});
+
+ LiteralUtil::EachCell<float>(
+ *reshape, [&](tensorflow::gtl::ArraySlice<int64> indices, float value) {
+ EXPECT_EQ(value,
+ LiteralUtil::Get<float>(*original, {indices[2], indices[3],
+ indices[0], indices[1]}));
+ });
+}
+
+TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
+ // Tests that using Relayout on an array is equivalent to creating it in the
+ // target layout in the first place.
+ auto dim0minor_relaid_to_dim0major = LiteralUtil::Relayout(
+ *literal_r4_2x2x3x3_dim0minor_, layout_r4_dim0major_);
+ EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0major_,
+ *dim0minor_relaid_to_dim0major));
+
+ auto dim0major_relaid_to_dim0minor = LiteralUtil::Relayout(
+ *literal_r4_2x2x3x3_dim0major_, layout_r4_dim0minor_);
+ EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0minor_,
+ *dim0major_relaid_to_dim0minor));
+}
+
+TEST_F(LiteralUtilTest, TestR2LinearLayout) {
+ // Test expected memory layout of R2 dim0-minor (column-major) literal.
+ auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int>(
+ {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_);
+ EXPECT_EQ(mat_dim0minor->s32s_size(), 6);
+ EXPECT_MATCH(testing::PBToVec<int32>(mat_dim0minor->s32s()),
+ testing::VectorMatcher<int32>({1, 4, 2, 5, 3, 6}));
+
+ // Test expected memory layout when using Relayout to row major.
+ auto relaid_mat_to_dim0major =
+ LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_);
+ EXPECT_MATCH(testing::PBToVec<int32>(relaid_mat_to_dim0major->s32s()),
+ testing::VectorMatcher<int32>({1, 2, 3, 4, 5, 6}));
+
+ // Test expected memory layout of R2 created with dim0-major (row-major).
+ auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int>(
+ {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_);
+ EXPECT_EQ(mat_dim0major->s32s_size(), 6);
+ EXPECT_MATCH(testing::PBToVec<int32>(mat_dim0major->s32s()),
+ testing::VectorMatcher<int32>({1, 2, 3, 4, 5, 6}));
+
+ // Test expected memory layout when using Relayout to column major.
+ auto relaid_mat_to_dim0minor =
+ LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_);
+ EXPECT_MATCH(testing::PBToVec<int32>(relaid_mat_to_dim0minor->s32s()),
+ testing::VectorMatcher<int32>({1, 4, 2, 5, 3, 6}));
+}
+
+TEST_F(LiteralUtilTest, TestR3LinearLayout) {
+ // Test expected memory layout of R3 dim0-minor (column-major) literal.
+ Array3D<int> arr3d(
+ // clang-format off
+ {
+ {
+ {1, 2, 3},
+ {4, 5, 6},
+ },
+ {
+ {7, 8, 9},
+ {10, 11, 12},
+ },
+ }); // clang-format on
+ auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0minor_);
+
+ EXPECT_EQ(lit_dim0minor->s32s_size(), 12);
+ std::vector<int> expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12};
+ EXPECT_MATCH(testing::PBToVec<int32>(lit_dim0minor->s32s()),
+ testing::VectorMatcher<int32>(expected_dim0minor));
+
+ // Test expected memory layout when using Relayout to row major.
+ auto relaid_lit_to_dim0major =
+ LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_);
+ std::vector<int> expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+ EXPECT_MATCH(testing::PBToVec<int32>(relaid_lit_to_dim0major->s32s()),
+ testing::VectorMatcher<int32>(expected_dim0major));
+
+ // Test expected memory layout of R3 created with dim0-major (row-major).
+ auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout<int>(
+ arr3d, layout_r3_dim0major_);
+ EXPECT_EQ(lit_dim0major->s32s_size(), 12);
+ EXPECT_MATCH(testing::PBToVec<int32>(lit_dim0major->s32s()),
+ testing::VectorMatcher<int32>(expected_dim0major));
+
+ // Test expected memory layout when using Relayout to column major.
+ auto relaid_lit_to_dim0minor =
+ LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_);
+ EXPECT_MATCH(testing::PBToVec<int32>(relaid_lit_to_dim0minor->s32s()),
+ testing::VectorMatcher<int32>(expected_dim0minor));
+}
+
+TEST_F(LiteralUtilTest, SliceR0S32) {
+ auto input = LiteralUtil::CreateR0<int32>(1);
+ auto result = LiteralUtil::Slice(*input, {}, {});
+ EXPECT_TRUE(LiteralUtil::Equal(*input, *result));
+}
+
+TEST_F(LiteralUtilTest, SliceR1F32) {
+ auto input = LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0, 4.0, 5.0});
+ auto result = LiteralUtil::Slice(*input, {3}, {4});
+ auto expected = LiteralUtil::CreateR1<float>({4.0});
+ EXPECT_TRUE(LiteralUtil::Equal(*expected, *result));
+}
+
+TEST_F(LiteralUtilTest, SliceR2U32) {
+ auto input_3x4 = LiteralUtil::CreateR2<uint32>(
+ {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}});
+ auto result = LiteralUtil::Slice(*input_3x4, {0, 2}, {2, 4});
+ auto expected = LiteralUtil::CreateR2<uint32>({{3, 4}, {7, 8}});
+ EXPECT_TRUE(LiteralUtil::Equal(*expected, *result));
+}
+
+TEST_F(LiteralUtilTest, SliceR3U32Full) {
+ auto input_2x3x2 = LiteralUtil::CreateR3<uint32>(
+ {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}});
+ auto result = LiteralUtil::Slice(*input_2x3x2, {0, 0, 0}, {2, 3, 2});
+ EXPECT_TRUE(LiteralUtil::Equal(*input_2x3x2, *result));
+}
+
+TEST_F(LiteralUtilTest, PopulateR1S64) {
+ Literal output;
+ LiteralUtil::PopulateR1<int64>({77}, &output);
+ auto expected = LiteralUtil::CreateR1<int64>({77});
+ EXPECT_TRUE(LiteralUtil::Equal(output, *expected));
+}
+
+TEST_F(LiteralUtilTest, PopulateR1U64) {
+  Literal output;
+  LiteralUtil::PopulateR1<uint64>({77, 88}, &output);
+  auto expected = LiteralUtil::CreateR1<uint64>({77, 88});
+ EXPECT_TRUE(LiteralUtil::Equal(output, *expected));
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR0F32) {
+ Literal output;
+ LiteralUtil::PopulateWithValue<float>(2.5f, {}, &output);
+ auto expected = LiteralUtil::CreateR0<float>(2.5f);
+ EXPECT_TRUE(LiteralUtil::Equal(output, *expected));
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR1S64) {
+ Literal output;
+ LiteralUtil::PopulateWithValue<int64>(-7, {3}, &output);
+ auto expected = LiteralUtil::CreateR1<int64>({-7, -7, -7});
+ EXPECT_TRUE(LiteralUtil::Equal(output, *expected));
+}
+
+TEST_F(LiteralUtilTest, PopulateWithValueR2U64) {
+ Literal output;
+ LiteralUtil::PopulateWithValue<uint64>(42, {2, 2}, &output);
+ auto expected = LiteralUtil::CreateR2<uint64>({{42, 42}, {42, 42}});
+ EXPECT_TRUE(LiteralUtil::Equal(output, *expected));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h
new file mode 100644
index 0000000000..51d0d5f86f
--- /dev/null
+++ b/tensorflow/compiler/xla/map_util.h
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+// FindOrDie returns a const reference to the value associated with
+// the given key if it exists. Crashes otherwise.
+//
+// This is intended as a replacement for operator[] as an rvalue (for reading)
+// when the key is guaranteed to exist.
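+//
+// Example (illustrative sketch; `id_to_name` is a hypothetical map):
+//
+//   std::map<int, string> id_to_name = {{1, "add"}, {2, "mul"}};
+//   const string& name = FindOrDie(id_to_name, 1);  // CHECK-fails if absent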
+template <class Collection>
+const typename Collection::value_type::second_type& FindOrDie(
+ const Collection& collection,
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::const_iterator it = collection.find(key);
+ CHECK(it != collection.end()) << "Map key not found: " << key;
+ return it->second;
+}
+
+// Same as above, but returns a non-const reference.
+template <class Collection>
+typename Collection::value_type::second_type& FindOrDie(
+ Collection& collection, // NOLINT
+ const typename Collection::value_type::first_type& key) {
+ typename Collection::iterator it = collection.find(key);
+ CHECK(it != collection.end()) << "Map key not found: " << key;
+ return it->second;
+}
+
+// Inserts the key-value pair into the collection. Dies if key was already
+// present.
+template <class Collection>
+void InsertOrDie(Collection* const collection,
+ const typename Collection::value_type::first_type& key,
+ const typename Collection::value_type::second_type& data) {
+ auto p = collection->insert(std::make_pair(key, data));
+ CHECK(p.second) << "duplicate key: " << key;
+}
+
+// Returns true if and only if the given collection contains the given key.
+template <class Collection, class Key>
+bool ContainsKey(const Collection& collection, const Key& key) {
+ return collection.find(key) != collection.end();
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_
diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc
new file mode 100644
index 0000000000..21766a2a0c
--- /dev/null
+++ b/tensorflow/compiler/xla/packed_literal_reader.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/packed_literal_reader.h"
+
+#include <limits>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+PackedLiteralReader::PackedLiteralReader(tensorflow::RandomAccessFile* file)
+ : file_(file), offset_(0) {}
+
+PackedLiteralReader::~PackedLiteralReader() { delete file_; }
+
+StatusOr<std::unique_ptr<Literal>> PackedLiteralReader::Read(
+ const Shape& shape, const Layout* layout) {
+ VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape)
+ << " layout: "
+ << (layout == nullptr ? "<none>" : layout->ShortDebugString());
+ auto result = MakeUnique<Literal>();
+ *result->mutable_shape() = shape;
+ if (layout != nullptr) {
+ TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape));
+ *result->mutable_shape()->mutable_layout() = *layout;
+ }
+
+ if (shape.element_type() != F32) {
+ return Unimplemented(
+ "not yet implemented element type for packed literal reading: %s",
+ PrimitiveType_Name(shape.element_type()).c_str());
+ }
+
+ int64 elements = ShapeUtil::ElementsIn(shape);
+ LiteralUtil::Resize(elements, std::numeric_limits<float>::quiet_NaN(),
+ result.get());
+ tensorflow::protobuf::RepeatedField<float>* field = result->mutable_f32s();
+ char* data = tensorflow::bit_cast<char*>(field->mutable_data());
+ uint64 bytes = elements * sizeof(float);
+ tensorflow::StringPiece sp;
+ auto s = file_->Read(offset_, bytes, &sp, data);
+ offset_ += sp.size();
+ if (!s.ok()) {
+ return s;
+ } else {
+ // Success: make sure we move the data into the right place if the Read
+    // call decided to return data somewhere other than "data".
+ CHECK_EQ(sp.size(), bytes);
+ if (sp.data() != data) {
+ memcpy(data, sp.data(), sp.size());
+ }
+ }
+ VLOG(3) << "read shape from file: " << ShapeUtil::HumanString(shape);
+ return std::move(result);
+}
+
+bool PackedLiteralReader::IsExhausted() const {
+ // Try to read a single byte from offset_. If we can't, we've
+ // exhausted the data.
+ char single_byte[1];
+ tensorflow::StringPiece sp;
+ auto s = file_->Read(offset_, sizeof(single_byte), &sp, single_byte);
+ return !s.ok();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/packed_literal_reader.h b/tensorflow/compiler/xla/packed_literal_reader.h
new file mode 100644
index 0000000000..563d978cf5
--- /dev/null
+++ b/tensorflow/compiler/xla/packed_literal_reader.h
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_
+#define TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Reads packed data from a metadata-less file as requested by a user (who must
+// know its internal format). These are yielded as (structured) literal values.
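+//
+// Example usage (a sketch for illustration; "env", the file path, and the
+// elided error handling are assumptions, not part of this API):
+//
+//   std::unique_ptr<tensorflow::RandomAccessFile> file;
+//   TF_CHECK_OK(env->NewRandomAccessFile("/tmp/f32s.data", &file));
+//   PackedLiteralReader reader(file.release());  // reader now owns the file.
+//   Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+//   while (!reader.IsExhausted()) {
+//     StatusOr<std::unique_ptr<Literal>> packed = reader.Read(shape);
+//     ...
+//   }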
+class PackedLiteralReader {
+ public:
+ // Ownership of file is passed to this instance -- this instance takes
+ // responsibility for closing it.
+ explicit PackedLiteralReader(tensorflow::RandomAccessFile* file);
+ ~PackedLiteralReader();
+
+ // Yields the next packed literal with shape "shape" as read from the
+ // underlying file stream.
+ //
+ // Layout is optional. If it is not provided, no layout is set on the literal
+ // that is produced.
+ StatusOr<std::unique_ptr<Literal>> Read(const Shape& shape,
+ const Layout* layout = nullptr);
+
+ // Returns whether the input file has been fully exhausted; i.e. all available
+ // packed literals have been read and we're at the end of the file.
+ bool IsExhausted() const;
+
+ private:
+ tensorflow::RandomAccessFile* file_; // We own and close in our destructor
+ uint64 offset_; // Next file offset to read from
+
+ TF_DISALLOW_COPY_AND_ASSIGN(PackedLiteralReader);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PACKED_LITERAL_READER_H_
diff --git a/tensorflow/compiler/xla/port/BUILD b/tensorflow/compiler/xla/port/BUILD
new file mode 100644
index 0000000000..6fc5f1185c
--- /dev/null
+++ b/tensorflow/compiler/xla/port/BUILD
@@ -0,0 +1,33 @@
+licenses(["notice"]) # Apache 2.0
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+ visibility = ["//tensorflow/compiler/xla:internal"],
+)
+
+cc_library(
+ name = "initialize",
+ hdrs = ["initialize.h"],
+ visibility = [
+ "//tensorflow/compiler/xla:__subpackages__",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/port/initialize.h b/tensorflow/compiler/xla/port/initialize.h
new file mode 100644
index 0000000000..13d9632f97
--- /dev/null
+++ b/tensorflow/compiler/xla/port/initialize.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_
+#define TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_
+
+#undef REGISTER_MODULE_INITIALIZER
+
+namespace xla {
+
+class Initializer {
+ public:
+ typedef void (*InitializerFunc)();
+ explicit Initializer(InitializerFunc func) { func(); }
+};
+
+} // namespace xla
+
+#define REGISTER_INITIALIZER(type, name, body) \
+ static void google_init_##type##_##name() { body; } \
+ xla::Initializer google_initializer_##type##_##name( \
+ google_init_##type##_##name)
+
+#define REGISTER_MODULE_INITIALIZER(name, body) \
+ REGISTER_INITIALIZER(module, name, body)
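+
+// Example usage (a sketch; DoOneTimeSetup is a hypothetical function):
+//
+//   REGISTER_MODULE_INITIALIZER(my_module, { DoOneTimeSetup(); });
+//
+// The body runs when the corresponding static xla::Initializer object is
+// constructed, i.e. during static initialization, before main() is entered.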
+
+#endif // TENSORFLOW_COMPILER_XLA_PORT_INITIALIZE_H_
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
new file mode 100644
index 0000000000..e3909ae8e9
--- /dev/null
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -0,0 +1,133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/primitive_util.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace primitive_util {
+
+template <>
+PrimitiveType NativeToPrimitiveType<bool>() {
+ return PRED;
+}
+
+// Unsigned integer
+template <>
+PrimitiveType NativeToPrimitiveType<uint8>() {
+ return U8;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint16>() {
+ return U16;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint32>() {
+ return U32;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint64>() {
+ return U64;
+}
+
+// Signed integer
+template <>
+PrimitiveType NativeToPrimitiveType<int8>() {
+ return S8;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<int16>() {
+ return S16;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<int32>() {
+ return S32;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<int64>() {
+ return S64;
+}
+
+// Floating point
+template <>
+PrimitiveType NativeToPrimitiveType<float>() {
+ return F32;
+}
+
+template <>
+PrimitiveType NativeToPrimitiveType<double>() {
+ return F64;
+}
+
+bool IsFloatingPointType(PrimitiveType type) {
+ return type == F16 || type == F32 || type == F64;
+}
+
+bool IsSignedIntegralType(PrimitiveType type) {
+ return type == S8 || type == S16 || type == S32 || type == S64;
+}
+
+bool IsUnsignedIntegralType(PrimitiveType type) {
+ return type == U8 || type == U16 || type == U32 || type == U64;
+}
+
+bool IsIntegralType(PrimitiveType type) {
+ return IsUnsignedIntegralType(type) || IsSignedIntegralType(type);
+}
+
+int BitWidth(PrimitiveType type) {
+ switch (type) {
+ case PRED:
+ return 1;
+
+ case S8:
+ case U8:
+ return 8;
+
+ case S16:
+ case U16:
+ case F16:
+ return 16;
+
+ case U32:
+ case S32:
+ case F32:
+ return 32;
+
+ case U64:
+ case S64:
+ case F64:
+ return 64;
+
+ case TUPLE:
+ LOG(FATAL) << "TUPLE is an invalid type for BitWidth";
+
+ case OPAQUE:
+ LOG(FATAL) << "OPAQUE is an invalid type for BitWidth";
+
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << type;
+ }
+}
+
+} // namespace primitive_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
new file mode 100644
index 0000000000..78f0ee6f59
--- /dev/null
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -0,0 +1,157 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utilities for dealing with XLA primitive types.
+
+#ifndef TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
+
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace primitive_util {
+
+// Returns the XLA primitive type (e.g., F32) corresponding to the given
+// template parameter native type (e.g., float).
+template <typename NativeT>
+PrimitiveType NativeToPrimitiveType() {
+ // Make the expression depend on the template parameter NativeT so
+  // that this compile-time error only appears if this function is
+ // instantiated with some concrete type that is not specialized
+ // below.
+ static_assert(!std::is_same<NativeT, NativeT>::value,
+ "Cannot map native type to primitive type.");
+ return PRIMITIVE_TYPE_INVALID;
+}
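+
+// For example (a sketch of intended usage; PrimitiveTypeToNative is declared
+// further below in this header):
+//
+//   PrimitiveType t = NativeToPrimitiveType<float>();  // t == F32
+//   static_assert(
+//       std::is_same<PrimitiveTypeToNative<F32>::type, float>::value, "");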
+
+// Declarations of specializations for each native type which corresponds to
+// an XLA primitive type.
+template <>
+PrimitiveType NativeToPrimitiveType<bool>();
+
+// Unsigned integer
+template <>
+PrimitiveType NativeToPrimitiveType<uint8>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint16>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint32>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<uint64>();
+
+// Signed integer
+template <>
+PrimitiveType NativeToPrimitiveType<int8>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<int16>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<int32>();
+
+template <>
+PrimitiveType NativeToPrimitiveType<int64>();
+
+// Floating point
+template <>
+PrimitiveType NativeToPrimitiveType<float>();
+template <>
+PrimitiveType NativeToPrimitiveType<double>();
+
+bool IsFloatingPointType(PrimitiveType type);
+
+bool IsSignedIntegralType(PrimitiveType type);
+
+bool IsUnsignedIntegralType(PrimitiveType type);
+
+bool IsIntegralType(PrimitiveType type);
+
+// Returns the number of bits in the representation for a given type.
+int BitWidth(PrimitiveType type);
+
+// Returns the native type (e.g., float) corresponding to the given template
+// parameter XLA primitive type (e.g., F32).
+template <PrimitiveType>
+struct PrimitiveTypeToNative;
+
+// Declarations of specializations for each XLA primitive type which
+// corresponds to a native type.
+template <>
+struct PrimitiveTypeToNative<PRED> {
+ using type = bool;
+};
+
+// Unsigned integer
+template <>
+struct PrimitiveTypeToNative<U8> {
+ using type = uint8;
+};
+
+template <>
+struct PrimitiveTypeToNative<U16> {
+ using type = uint16;
+};
+
+template <>
+struct PrimitiveTypeToNative<U32> {
+ using type = uint32;
+};
+
+template <>
+struct PrimitiveTypeToNative<U64> {
+ using type = uint64;
+};
+
+// Signed integer
+template <>
+struct PrimitiveTypeToNative<S8> {
+ using type = int8;
+};
+
+template <>
+struct PrimitiveTypeToNative<S16> {
+ using type = int16;
+};
+
+template <>
+struct PrimitiveTypeToNative<S32> {
+ using type = int32;
+};
+
+template <>
+struct PrimitiveTypeToNative<S64> {
+ using type = int64;
+};
+
+// Floating point
+template <>
+struct PrimitiveTypeToNative<F32> {
+ using type = float;
+};
+template <>
+struct PrimitiveTypeToNative<F64> {
+ using type = double;
+};
+
+} // namespace primitive_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PRIMITIVE_UTIL_H_
diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc
new file mode 100644
index 0000000000..adb2e99ad2
--- /dev/null
+++ b/tensorflow/compiler/xla/protobuf_util.cc
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace protobuf_util {
+
+bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
+ const tensorflow::protobuf::Message& m2) {
+ // This is a bit fast and loose, but avoids introducing a dependency on
+ // the much more complex protobuf::util::MessageDifferencer class. For
+ // our purposes we just say that two protobufs are equal if their serialized
+ // representations are equal.
+ string serialized1, serialized2;
+ m1.AppendToString(&serialized1);
+ m2.AppendToString(&serialized2);
+ return (serialized1 == serialized2);
+}
+
+} // namespace protobuf_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h
new file mode 100644
index 0000000000..36247f1bde
--- /dev/null
+++ b/tensorflow/compiler/xla/protobuf_util.h
@@ -0,0 +1,35 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
+
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+namespace protobuf_util {
+
+// Returns true if m1 is equal to m2.
+//
+// WARNING: We use protocol buffer serialization and then check for
+// equality of the serialized representation, which may miss some
+// cases of equality. However, for the purposes of the XLA code
+// base, this form of equality checking is sufficient.
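+//
+// Example (a sketch; "lhs_shape" and "rhs_shape" stand for any two protocol
+// buffer messages of the same type):
+//
+//   if (protobuf_util::ProtobufEquals(lhs_shape, rhs_shape)) {
+//     VLOG(1) << "shapes serialize identically";
+//   }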
+extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
+ const tensorflow::protobuf::Message& m2);
+} // namespace protobuf_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_
diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h
new file mode 100644
index 0000000000..fa67030313
--- /dev/null
+++ b/tensorflow/compiler/xla/ptr_util.h
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
+
+// Utility functions for pointers.
+
+#include <stddef.h>
+
+#include <memory>
+#include <type_traits>
+#include <utility>
+
+namespace xla {
+
+namespace internal {
+
+// Trait to select overloads and return types for MakeUnique.
+template <typename T>
+struct MakeUniqueResult {
+ using scalar = std::unique_ptr<T>;
+};
+template <typename T>
+struct MakeUniqueResult<T[]> {
+ using array = std::unique_ptr<T[]>;
+};
+template <typename T, size_t N>
+struct MakeUniqueResult<T[N]> {
+ using invalid = void;
+};
+
+} // namespace internal
+
+// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type.
+// Example:
+// X* NewX(int, int);
+// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr<X>.
+//
+// WrapUnique is useful for capturing the output of a raw pointer factory.
+// However, prefer 'MakeUnique<T>(args...)' over 'WrapUnique(new T(args...))'.
+// auto x = WrapUnique(new X(1, 2)); // works, but nonideal.
+// auto x = MakeUnique<X>(1, 2); // safer, standard, avoids raw 'new'.
+//
+// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]).
+template <typename T>
+std::unique_ptr<T> WrapUnique(T* ptr) {
+ static_assert(!std::is_array<T>::value || std::extent<T>::value != 0,
+ "types T[0] or T[] are unsupported");
+ return std::unique_ptr<T>(ptr);
+}
+
+template <typename T, typename... Args>
+typename internal::MakeUniqueResult<T>::scalar MakeUnique(Args&&... args) {
+ return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
+}
+
+// Overload for array of unknown bound.
+// The allocation of arrays needs to use the array form of new,
+// and cannot take element constructor arguments.
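+//
+// For example (a sketch): MakeUnique<float[]>(16) yields a
+// std::unique_ptr<float[]> holding 16 value-initialized floats.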
+template <typename T>
+typename internal::MakeUniqueResult<T>::array MakeUnique(size_t n) {
+ return std::unique_ptr<T>(new typename std::remove_extent<T>::type[n]());
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
new file mode 100644
index 0000000000..f03b158fa7
--- /dev/null
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -0,0 +1,540 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/reference_util.h"
+
+#include <array>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
+ const Array2D<float>& operand) {
+ auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
+ for (int64 w = 0; w < operand.width(); ++w) {
+ for (int64 h = 0; h < operand.height(); ++h) {
+ (*result)(w, h) = operand(h, w);
+ }
+ }
+
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
+ const Array2D<float>& lhs, const Array2D<float>& rhs) {
+ CHECK_EQ(lhs.width(), rhs.height());
+ int m = lhs.height();
+ int n = rhs.width();
+ int k = lhs.width();
+ auto result = MakeUnique<Array2D<float>>(m, n);
+ // Because Eigen is a header-oriented library, make sure that the Eigen code
+ // is the same as the code used by the CPU backend (otherwise the linker will
+ // randomly pick *some* definition).
+ __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+ /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
+ k,
+ /*transpose_lhs=*/0,
+ /*transpose_rhs=*/0);
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
+ const Array2D<double>& lhs, const Array2D<double>& rhs) {
+ CHECK_EQ(lhs.width(), rhs.height());
+ int m = lhs.height();
+ int n = rhs.width();
+ int k = lhs.width();
+ auto result = MakeUnique<Array2D<double>>(m, n);
+ // Because Eigen is a header-oriented library, make sure that the Eigen code
+ // is the same as the code used by the CPU backend (otherwise the linker will
+ // randomly pick *some* definition).
+ __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
+ /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
+ k,
+ /*transpose_lhs=*/0,
+ /*transpose_rhs=*/0);
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
+ const Array2D<float>& input) {
+ auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
+ for (int64 rowno = 0; rowno < input.height(); ++rowno) {
+    for (int64 colno = 0; colno < input.width(); ++colno) {
+ (*result)(rowno, colno) = input(rowno, colno);
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> kernel_stride, Padding padding) {
+ return ConvArray4DGeneralDimensions(
+ lhs, rhs, kernel_stride, padding,
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+}
+
+/* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
+ int64 window_len, int64 stride,
+ Padding padding) {
+ if (padding == Padding::kValid) {
+ return window_util::StridedBound(unpadded_width, window_len, stride);
+ }
+ return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
+}
+
+/* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
+ const Array4D<float>& operand, float init,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
+ std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
+ operand.n4()};
+ auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
+
+ std::vector<int64> window_counts(window.size(), 0);
+ std::vector<int64> pad_low(window.size(), 0);
+ for (int64 i = 0; i < window.size(); ++i) {
+ window_counts[i] =
+ WindowCount(dim_lengths[i], window[i], stride[i], padding);
+ pad_low[i] = padding_both[i].first;
+ }
+ auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
+ window_counts[2], window_counts[3]);
+ // Do a full 4D reduce window.
+ for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
+ for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
+ for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
+ for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
+ int64 i0_base = i0 * stride[0] - pad_low[0];
+ int64 i1_base = i1 * stride[1] - pad_low[1];
+ int64 i2_base = i2 * stride[2] - pad_low[2];
+ int64 i3_base = i3 * stride[3] - pad_low[3];
+
+ float val = init;
+ for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
+ for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
+ for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
+ for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
+ if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
+ i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
+ i0_base + i0_win < operand.n1() &&
+ i1_base + i1_win < operand.n2() &&
+ i2_base + i2_win < operand.n3() &&
+ i3_base + i3_win < operand.n4()) {
+ val += operand(i0_base + i0_win, i1_base + i1_win,
+ i2_base + i2_win, i3_base + i3_win);
+ }
+ }
+ }
+ }
+ }
+ (*result)(i0, i1, i2, i3) = val;
+ }
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array4D<float>>
+ReferenceUtil::SelectAndScatter4DGePlus(
+ const Array4D<float>& operand, const Array4D<float>& source, float init,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
+ Padding padding = same_padding ? Padding::kSame : Padding::kValid;
+ auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
+ operand.n3(), operand.n4());
+ std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
+ operand.n4()};
+ auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
+  // Fill the output with the initial value.
+ result->Fill(init);
+
+ std::vector<int64> window_counts(window.size(), 0);
+ std::vector<int64> pad_low(window.size(), 0);
+ for (int64 i = 0; i < window.size(); ++i) {
+ window_counts[i] =
+ WindowCount(dim_lengths[i], window[i], stride[i], padding);
+ pad_low[i] = padding_both[i].first;
+ }
+ CHECK_EQ(window_counts[0], source.n1());
+ CHECK_EQ(window_counts[1], source.n2());
+ CHECK_EQ(window_counts[2], source.n3());
+ CHECK_EQ(window_counts[3], source.n4());
+
+  // Do a full 4D select and scatter.
+ for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
+ for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
+ for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
+ for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
+ // Now we are inside a window and need to find the max and the argmax.
+ int64 i0_base = i0 * stride[0] - pad_low[0];
+ int64 i1_base = i1 * stride[1] - pad_low[1];
+ int64 i2_base = i2 * stride[2] - pad_low[2];
+ int64 i3_base = i3 * stride[3] - pad_low[3];
+ int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
+ int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
+ int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
+ int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
+ float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
+ for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
+ for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
+ for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
+ for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
+ if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
+ i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
+ i0_base + i0_win < operand.n1() &&
+ i1_base + i1_win < operand.n2() &&
+ i2_base + i2_win < operand.n3() &&
+ i3_base + i3_win < operand.n4()) {
+ float tmp = operand(i0_base + i0_win, i1_base + i1_win,
+ i2_base + i2_win, i3_base + i3_win);
+ if (tmp >= val) {
+ val = tmp;
+ scatter_0 = i0_base + i0_win;
+ scatter_1 = i1_base + i1_win;
+ scatter_2 = i2_base + i2_win;
+ scatter_3 = i3_base + i3_win;
+ }
+ }
+ }
+ }
+ }
+ }
+ (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
+ source(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array4D<float>>
+ReferenceUtil::ConvArray4DGeneralDimensions(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> kernel_stride, Padding padding,
+ ConvolutionDimensionNumbers dimension_numbers) {
+ return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
+ {1, 1}, {1, 1}, dimension_numbers);
+}
+
+/* static */ std::unique_ptr<Array4D<float>>
+ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> kernel_stride, Padding padding,
+ std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
+ ConvolutionDimensionNumbers dnums) {
+ std::array<int64, 4> lhs_dimensions{{lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()}};
+ std::array<int64, 4> rhs_dimensions{{rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()}};
+
+ const int64 ksy = kernel_stride.first;
+ const int64 ksx = kernel_stride.second;
+ const int64 dy = lhs_dilation.first;
+ const int64 dx = lhs_dilation.second;
+ const int64 dky = rhs_dilation.first;
+ const int64 dkx = rhs_dilation.second;
+ CHECK_GE(dky, 1);
+ CHECK_GE(dkx, 1);
+ CHECK_GE(dy, 1);
+ CHECK_GE(dx, 1);
+
+ // Get all dimension sizes in lhs and rhs based on the given convolution
+ // dimension configuration.
+ const int64 ix = window_util::DilatedBound(
+ lhs_dimensions[dnums.spatial_dimensions(1)], dx);
+ const int64 iy = window_util::DilatedBound(
+ lhs_dimensions[dnums.spatial_dimensions(0)], dy);
+ const int64 iz = lhs_dimensions[dnums.feature_dimension()];
+ const int64 samples = lhs_dimensions[dnums.batch_dimension()];
+ const int64 kx = window_util::DilatedBound(
+ rhs_dimensions[dnums.kernel_spatial_dimensions(1)], dkx);
+ const int64 ky = window_util::DilatedBound(
+ rhs_dimensions[dnums.kernel_spatial_dimensions(0)], dky);
+ const int64 oz = rhs_dimensions[dnums.kernel_output_feature_dimension()];
+ {
+ const int64 kiz = rhs_dimensions[dnums.kernel_input_feature_dimension()];
+ CHECK_EQ(kiz, iz);
+ }
+
+ if (padding == Padding::kSame) {
+ // We reject same padding with kernel striding, since it's somewhat
+ // nonsensical. We can always follow up to implement this with the desired
+ // semantics if anybody actually uses it.
+ CHECK_EQ(1, ksy);
+ CHECK_EQ(1, ksx);
+ }
+
+ const int64 ox =
+ padding == Padding::kSame ? ix : window_util::StridedBound(ix, kx, ksx);
+ const int64 oy =
+ padding == Padding::kSame ? iy : window_util::StridedBound(iy, ky, ksy);
+ const int64 istartx =
+ padding == Padding::kValid ? 0 : kx % 2 == 0 ? -(kx / 2 - 1) : -kx / 2;
+ const int64 istarty =
+ padding == Padding::kValid ? 0 : ky % 2 == 0 ? -(ky / 2 - 1) : -ky / 2;
+ // Create the output result array and reset the values to 0.
+ std::array<int64, 4> result_dimensions;
+ result_dimensions[dnums.batch_dimension()] = samples;
+ result_dimensions[dnums.feature_dimension()] = oz;
+ result_dimensions[dnums.spatial_dimensions(0)] = oy;
+ result_dimensions[dnums.spatial_dimensions(1)] = ox;
+ auto result =
+ MakeUnique<Array4D<float>>(result_dimensions[0], result_dimensions[1],
+ result_dimensions[2], result_dimensions[3]);
+ result->Fill(0.0);
+
+ // Lambda to access the lhs operand at the given 4D index.
+ const auto lhs_element = [&](int64 batch, int64 feature, int64 height,
+ int64 width) {
+ if (height % dy != 0 || width % dx != 0) {
+ return 0.0f;
+ }
+
+ std::array<int64, 4> index;
+ index[dnums.batch_dimension()] = batch;
+ index[dnums.feature_dimension()] = feature;
+ index[dnums.spatial_dimensions(0)] = height / dy;
+ index[dnums.spatial_dimensions(1)] = width / dx;
+ return lhs(index[0], index[1], index[2], index[3]);
+ };
+
+ // Lambda to access the rhs operand at the given 4D index.
+ const auto rhs_element = [&](int64 kernel_output_feature,
+ int64 kernel_input_feature, int64 height,
+ int64 width) {
+ CHECK_EQ(height % dky, 0);
+ CHECK_EQ(width % dkx, 0);
+ std::array<int64, 4> index;
+ index[dnums.kernel_output_feature_dimension()] = kernel_output_feature;
+ index[dnums.kernel_input_feature_dimension()] = kernel_input_feature;
+ index[dnums.kernel_spatial_dimensions(0)] = height / dky;
+ index[dnums.kernel_spatial_dimensions(1)] = width / dkx;
+ return rhs(index[0], index[1], index[2], index[3]);
+ };
+
+ // Lambda to access the result data at the given 4D index.
+ const auto result_element = [&](int64 batch, int64 kernel_output_feature,
+ int64 height, int64 width) -> float& {
+ std::array<int64, 4> index;
+ index[dnums.batch_dimension()] = batch;
+ index[dnums.feature_dimension()] = kernel_output_feature;
+ index[dnums.spatial_dimensions(0)] = height;
+ index[dnums.spatial_dimensions(1)] = width;
+ return (*result)(index[0], index[1], index[2], index[3]);
+ };
+
+ for (int64 oyi = 0; oyi < oy; ++oyi) {
+ for (int64 oxi = 0; oxi < ox; ++oxi) {
+ for (int64 sample = 0; sample < samples; ++sample) {
+ for (int64 izi = 0; izi < iz; ++izi) {
+ for (int64 ozi = 0; ozi < oz; ++ozi) {
+ for (int64 kyi = 0; kyi < ky; kyi += dky) {
+ for (int64 kxi = 0; kxi < kx; kxi += dkx) {
+ int64 iyi = istarty + ksy * oyi + kyi;
+ int64 ixi = istartx + ksx * oxi + kxi;
+ float input = (iyi >= iy || ixi >= ix || iyi < 0 || ixi < 0)
+ ? 0.0
+ : lhs_element(sample, izi, iyi, ixi);
+ float gain = rhs_element(ozi, izi, kyi, kxi);
+ float addend = input * gain;
+ result_element(sample, ozi, oyi, oxi) += addend;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<std::vector<float>>
+ReferenceUtil::ReduceToColArray2D(
+ const Array2D<float>& matrix, float init,
+ std::function<float(float, float)> reduce_function) {
+ int64 rows = matrix.height();
+ int64 cols = matrix.width();
+ auto result = MakeUnique<std::vector<float>>();
+ for (int64 i = 0; i < rows; ++i) {
+ float acc = init;
+ for (int64 j = 0; j < cols; ++j) {
+ acc = reduce_function(acc, matrix(i, j));
+ }
+ result->push_back(acc);
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<std::vector<float>>
+ReferenceUtil::ReduceToRowArray2D(
+ const Array2D<float>& matrix, float init,
+ std::function<float(float, float)> reduce_function) {
+ int64 rows = matrix.height();
+ int64 cols = matrix.width();
+ auto result = MakeUnique<std::vector<float>>();
+ for (int64 i = 0; i < cols; ++i) {
+ float acc = init;
+ for (int64 j = 0; j < rows; ++j) {
+ acc = reduce_function(acc, matrix(j, i));
+ }
+ result->push_back(acc);
+ }
+ return result;
+}
+
+/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
+ const Array4D<float>& array, float init,
+ tensorflow::gtl::ArraySlice<int64> dims,
+ std::function<float(float, float)> reduce_function) {
+ std::vector<float> result;
+ CHECK_EQ(dims.size(), 3);
+ const std::set<int64> dim_set(dims.begin(), dims.end());
+ CHECK_EQ(dim_set.size(), 3);
+ for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
+ for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
+ ++a1) {
+ for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
+ ++a2) {
+ for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
+ ++a3) {
+ float accumulator = init;
+ for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
+ ++i0) {
+ for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
+ ++i1) {
+ for (int64 i2 = 0;
+ i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
+ for (int64 i3 = 0;
+ i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
+ accumulator = reduce_function(
+ accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
+ }
+ }
+ }
+ }
+ result.push_back(accumulator);
+ }
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
+ const Array3D<float>& array, float init,
+ tensorflow::gtl::ArraySlice<int64> dims,
+ std::function<float(float, float)> reduce_function) {
+ CHECK_EQ(dims.size(), 1);
+ int64 rows = dims[0] == 0 ? array.n2() : array.n1();
+ int64 cols = dims[0] == 2 ? array.n2() : array.n3();
+ auto result = MakeUnique<Array2D<float>>(rows, cols);
+ result->Fill(init);
+ for (int i0 = 0; i0 < array.n1(); ++i0) {
+ for (int i1 = 0; i1 < array.n2(); ++i1) {
+ for (int i2 = 0; i2 < array.n3(); ++i2) {
+ int64 row = dims[0] == 0 ? i1 : i0;
+ int64 col = dims[0] == 2 ? i1 : i2;
+ (*result)(row, col) =
+ reduce_function((*result)(row, col), array(i0, i1, i2));
+ }
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
+ const Array2D<float>& matrix,
+ const std::function<float(float)>& map_function) {
+ int64 rows = matrix.height();
+ int64 cols = matrix.width();
+ auto result = MakeUnique<Array2D<float>>(rows, cols);
+ for (int64 i = 0; i < rows; ++i) {
+ for (int64 j = 0; j < cols; ++j) {
+ (*result)(i, j) = map_function(matrix(i, j));
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
+ const Array2D<float>& lhs, const Array2D<float>& rhs,
+ const std::function<float(float, float)>& map_function) {
+ CHECK_EQ(lhs.height(), rhs.height());
+ CHECK_EQ(lhs.width(), rhs.width());
+ int64 rows = lhs.height();
+ int64 cols = rhs.width();
+ auto result = MakeUnique<Array2D<float>>(rows, cols);
+ for (int64 i = 0; i < rows; ++i) {
+ for (int64 j = 0; j < cols; ++j) {
+ (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
+ const Array2D<float>& matrix,
+ const std::function<float(float, int64, int64)>& map_function) {
+ int64 rows = matrix.height();
+ int64 cols = matrix.width();
+ auto result = MakeUnique<Array2D<float>>(rows, cols);
+ for (int64 i = 0; i < rows; ++i) {
+ for (int64 j = 0; j < cols; ++j) {
+ (*result)(i, j) = map_function(matrix(i, j), i, j);
+ }
+ }
+ return result;
+}
+
+/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::PadArray2D(
+ const Array2D<float>& operand, const PaddingConfig& padding,
+ const float pad) {
+ int64 in0 = operand.n1();
+ int64 high_padding0 = padding.dimensions(0).edge_padding_high();
+ int64 low_padding0 = padding.dimensions(0).edge_padding_low();
+ int64 interior_padding0 = padding.dimensions(0).interior_padding();
+ int64 out0 =
+ in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
+ int64 in1 = operand.n2();
+ int64 high_padding1 = padding.dimensions(1).edge_padding_high();
+ int64 low_padding1 = padding.dimensions(1).edge_padding_low();
+ int64 interior_padding1 = padding.dimensions(1).interior_padding();
+ int64 out1 =
+ in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
+ auto result = MakeUnique<Array2D<float>>(out0, out1);
+ result->Fill(pad);
+ int64 i0 = 0;
+ for (int64 o0 = low_padding0; o0 < out0 - high_padding0;
+ o0 += interior_padding0 + 1) {
+ int64 i1 = 0;
+ for (int64 o1 = low_padding1; o1 < out1 - high_padding1;
+ o1 += interior_padding1 + 1) {
+ (*result)(o0, o1) = operand(i0, i1);
+ ++i1;
+ }
+ ++i0;
+ }
+ return result;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
new file mode 100644
index 0000000000..27421b2ac4
--- /dev/null
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -0,0 +1,382 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
+
+#include <array>
+#include <functional>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Utility class for reference implementations of linear algebra routines.
+class ReferenceUtil {
+ public:
+ // Returns the result of a transpose operation on the input matrix.
+ static std::unique_ptr<Array2D<float>> TransposeArray2D(
+ const Array2D<float>& operand);
+
+ // Returns the result of a matrix multiply `lhs x rhs`.
+ static std::unique_ptr<Array2D<float>> MatmulArray2D(
+ const Array2D<float>& lhs, const Array2D<float>& rhs);
+ static std::unique_ptr<Array2D<double>> MatmulArray2D(
+ const Array2D<double>& lhs, const Array2D<double>& rhs);
+
+ // Converts the input operand to use f64 values instead of f32 values.
+ static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
+ const Array2D<float>& input);
+
+ // Returns the result of a convolution `lhs <conv> rhs`, with the default
+ // convolution dimension numbers returned from
+ // ComputationBuilder::CreateDefaultConvDimensionNumbers().
+ static std::unique_ptr<Array4D<float>> ConvArray4D(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> kernel_stride, Padding padding);
+
+ // Returns the result of a convolution `lhs <conv> rhs`, with the given
+ // convolution dimension numbers.
+ static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> kernel_stride, Padding padding,
+ ConvolutionDimensionNumbers dimension_numbers);
+
+ // Returns the result of a convolution `lhs <conv> rhs`, with the given
+ // dilation factors.
+ static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
+ const Array4D<float>& lhs, const Array4D<float>& rhs,
+ std::pair<int64, int64> stride, Padding padding,
+ std::pair<int64, int64> lhs_dilation,
+ std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);
+
+ // Returns the result of reducing a matrix to a column vector. init is the
+ // initial value for the reduce operation, and reduce_function is the function
+ // to apply for each reduction step.
+ static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
+ const Array2D<float>& matrix, float init,
+ std::function<float(float, float)> reduce_function);
+
+ // Returns the result of reducing a matrix to a row vector. init is the
+ // initial value for the reduce operation, and reduce_function is the function
+ // to apply for each reduction step.
+ static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
+ const Array2D<float>& matrix, float init,
+ std::function<float(float, float)> reduce_function);
+
+ // Performs a R2=>R1 reduction by reducing away the dimension specified in
+ // 'dimension_to_reduce'.
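+  //
+  // For example, reducing away dimension 1 of the 2x3 matrix
+  // [[1, 2, 3], [4, 5, 6]] with init 0 and addition yields {6, 15}
+  // (a sketch of intended usage):
+  //
+  //   std::vector<float> row_sums = ReferenceUtil::ReduceR2ToR1<float>(
+  //       matrix, /*dimension_to_reduce=*/1, 0.0f,
+  //       [](float a, float b) { return a + b; });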
+ template <typename T>
+ static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
+ int dimension_to_reduce, T init,
+ std::function<T(T, T)> freduce) {
+ std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
+ init);
+ for (int i0 = 0; i0 < input.n1(); ++i0) {
+ for (int i1 = 0; i1 < input.n2(); ++i1) {
+ int output = dimension_to_reduce == 0 ? i1 : i0;
+ result[output] = freduce(result[output], input(i0, i1));
+ }
+ }
+ return result;
+ }
+
+ // Returns the result of reducing the 4D array to a vector, reducing away
+ // the dimensions specified in dims.
+ static std::vector<float> Reduce4DTo1D(
+ const Array4D<float>& array, float init,
+ tensorflow::gtl::ArraySlice<int64> dims,
+ std::function<float(float, float)> reduce_function);
+
+ // Returns the result of reducing the 3D array to a 2D array, reducing away
+ // the dimensions specified in dims.
+ static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
+ const Array3D<float>& array, float init,
+ tensorflow::gtl::ArraySlice<int64> dims,
+ std::function<float(float, float)> reduce_function);
+
+ // Applies map_function to each element in the input (2D array) and returns
+ // the result.
+ static std::unique_ptr<Array2D<float>> MapArray2D(
+ const Array2D<float>& matrix,
+ const std::function<float(float)>& map_function);
+
+ // Applies map_function to each pair of corresponding elements in the two
+  // input arrays and returns the result.
+ static std::unique_ptr<Array2D<float>> MapArray2D(
+ const Array2D<float>& lhs, const Array2D<float>& rhs,
+ const std::function<float(float, float)>& map_function);
+
+ // Number of windows in a given dimension. Calculation taken from
+ // xla::MakePadding().
+ static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride,
+ Padding padding);
+
+ // Performs a 4D window reduction with Add as the function to apply.
+ static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
+ const Array4D<float>& operand, float init,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+
+  // Performs select-and-scatter with greater-than-or-equal as the select
+  // function and addition as the scatter function; same_padding chooses
+  // between kSame and kValid padding.
+ static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
+ const Array4D<float>& operand, const Array4D<float>& source, float init,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding);
+
+ // Concatenates the lhs and rhs arrays along the concatenate_dimension.
+ // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
+ // concatenated, so the arrays are stacked on top of each other.
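+  //
+  // For example (a sketch; Array2D's initializer-list constructor is assumed,
+  // as used in the tests):
+  //
+  //   Array2D<float> top({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}});
+  //   Array2D<float> bottom({{7.f, 8.f, 9.f}});
+  //   auto stacked = ReferenceUtil::Concat2D<float>(
+  //       top, bottom, /*concatenate_dimension=*/0);  // 3x3 result.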
+ template <typename T>
+ static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
+ const Array2D<T>& rhs,
+ int concatenate_dimension) {
+ CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
+ auto result = MakeUnique<Array2D<T>>(
+ concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
+ concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ // If we exceed the bounds of the LHS, draw from the RHS, where the
+ // result index is adjusted by the number of values present in the LHS.
+ (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
+ ? lhs(i0, i1)
+ : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
+ i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
+ }
+ }
+ return result;
+ }
+
+ // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
+ // and rhs must have the same dimensions except for the concatenate dimension.
+ template <typename T>
+ static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
+ const Array3D<T>& rhs,
+ int concatenate_dimension) {
+ CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
+ std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()};
+ std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
+ std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
+ for (int i = 0; i < 3; ++i) {
+ if (i != concatenate_dimension) {
+ out_dims[i] = lhs_dims[i];
+ CHECK_EQ(lhs_dims[i], rhs_dims[i]);
+ } else {
+ out_dims[i] = lhs_dims[i] + rhs_dims[i];
+ }
+ }
+ auto result = MakeUnique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ for (int64 i2 = 0; i2 < result->n3(); ++i2) {
+ (*result)(i0, i1, i2) =
+ i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
+ ? lhs(i0, i1, i2)
+ : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
+ i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
+ i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
+ }
+ }
+ }
+ return result;
+ }
+
+ // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
+ // and rhs must have the same dimensions except for the concatenate dimension.
+ template <typename T>
+ static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
+ const Array4D<T>& rhs,
+ int concatenate_dimension) {
+ CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
+ std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
+ std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
+ std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
+ for (int i = 0; i < 4; ++i) {
+ if (i != concatenate_dimension) {
+ out_dims[i] = lhs_dims[i];
+ CHECK_EQ(lhs_dims[i], rhs_dims[i]);
+ } else {
+ out_dims[i] = lhs_dims[i] + rhs_dims[i];
+ }
+ }
+ auto result = MakeUnique<Array4D<T>>(out_dims[0], out_dims[1], out_dims[2],
+ out_dims[3]);
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ for (int64 i2 = 0; i2 < result->n3(); ++i2) {
+ for (int64 i3 = 0; i3 < result->n4(); ++i3) {
+ (*result)(i0, i1, i2, i3) =
+ i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
+ ? lhs(i0, i1, i2, i3)
+ : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
+ i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
+ i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
+ i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ // Slices with modulo-wrapping.
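+  //
+  // For example (a sketch), ModSlice1D({1, 2, 3, 4}, /*start=*/3, /*size=*/3)
+  // returns {4, 1, 2}: indices wrap around modulo the input size.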
+ template <typename T>
+ static std::vector<T> ModSlice1D(const tensorflow::gtl::ArraySlice<T>& input,
+ int64 start, int64 size) {
+ std::vector<T> result;
+ for (int64 i = 0; i < size; ++i) {
+ result.push_back(input[(start + i) % input.size()]);
+ }
+ return result;
+ }
+
+ // Slices the input array given starting indices in each dimension and limit
+ // indices in each dimension.
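+  //
+  // For example (a sketch), Slice2D(input, {{0, 1}}, {{2, 3}}) returns the
+  // 2x2 block covering rows [0, 2) and columns [1, 3) of "input".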
+ template <typename T>
+ static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
+ std::array<int64, 2> starts,
+ std::array<int64, 2> limits) {
+ CHECK_LE(starts[0], input.n1());
+ CHECK_LE(starts[1], input.n2());
+ CHECK_LE(limits[0], input.n1());
+ CHECK_LE(limits[1], input.n2());
+ auto result =
+ MakeUnique<Array2D<T>>(limits[0] - starts[0], limits[1] - starts[1]);
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ (*result)(i0, i1) = input(starts[0] + i0, starts[1] + i1);
+ }
+ }
+ return result;
+ }
+
+ template <typename T>
+ static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
+ std::array<int64, 4> starts,
+ std::array<int64, 4> limits) {
+ CHECK_LE(starts[0], input.n1());
+ CHECK_LE(starts[1], input.n2());
+ CHECK_LE(starts[2], input.n3());
+ CHECK_LE(starts[3], input.n4());
+ CHECK_LE(limits[0], input.n1());
+ CHECK_LE(limits[1], input.n2());
+ CHECK_LE(limits[2], input.n3());
+ CHECK_LE(limits[3], input.n4());
+ auto result =
+ MakeUnique<Array4D<T>>(limits[0] - starts[0], limits[1] - starts[1],
+ limits[2] - starts[2], limits[3] - starts[3]);
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ for (int64 i2 = 0; i2 < result->n3(); ++i2) {
+ for (int64 i3 = 0; i3 < result->n4(); ++i3) {
+ (*result)(i0, i1, i2, i3) = input(starts[0] + i0, starts[1] + i1,
+ starts[2] + i2, starts[3] + i3);
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ template <typename T>
+ static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
+ std::array<int64, 3> starts,
+ std::array<int64, 3> limits) {
+ CHECK_LE(starts[0], input.n1());
+ CHECK_LE(starts[1], input.n2());
+ CHECK_LE(starts[2], input.n3());
+ CHECK_LE(limits[0], input.n1());
+ CHECK_LE(limits[1], input.n2());
+ CHECK_LE(limits[2], input.n3());
+ auto result = MakeUnique<Array3D<T>>(
+ limits[0] - starts[0], limits[1] - starts[1], limits[2] - starts[2]);
+ for (int64 i0 = 0; i0 < result->n1(); ++i0) {
+ for (int64 i1 = 0; i1 < result->n2(); ++i1) {
+ for (int64 i2 = 0; i2 < result->n3(); ++i2) {
+ (*result)(i0, i1, i2) =
+ input(starts[0] + i0, starts[1] + i1, starts[2] + i2);
+ }
+ }
+ }
+ return result;
+ }
+
+ // Applies map_function to each element in the input (2D array) and returns
+ // the result.
+  // The (row, column) index of each element is also passed to map_function.
+ static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
+ const Array2D<float>& matrix,
+ const std::function<float(float, int64, int64)>& map_function);
+
+ // Applies map_function to each element in the input (4D array) and returns
+ // the result.
+ template <typename F>
+ static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
+ F&& map_function) {
+ return MapWithIndexArray4D(input,
+ [&](float value, int64, int64, int64, int64) {
+ return map_function(value);
+ });
+ }
+
+ // Applies map_function to each element in the input (4D array) and returns
+ // the result.
+  // The (plane, depth, height, width) index of each element is also passed to
+  // map_function.
+ template <typename F>
+ static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
+ const Array4D<float>& input, F&& map_function) {
+ auto result = MakeUnique<Array4D<float>>(input.planes(), input.depth(),
+ input.height(), input.width());
+ for (int64 plane = 0; plane < input.planes(); ++plane) {
+ for (int64 depth = 0; depth < input.depth(); ++depth) {
+ for (int64 height = 0; height < input.height(); ++height) {
+ for (int64 width = 0; width < input.width(); ++width) {
+ (*result)(plane, depth, height, width) =
+ map_function(input(plane, depth, height, width), plane, depth,
+ height, width);
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ // Returns the result of a 2D pad on an input matrix.
+ static std::unique_ptr<Array2D<float>> PadArray2D(
+ const Array2D<float>& operand, const PaddingConfig& padding,
+ const float pad);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc
new file mode 100644
index 0000000000..c53351ca93
--- /dev/null
+++ b/tensorflow/compiler/xla/reference_util_test.cc
@@ -0,0 +1,306 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/reference_util.h"
+
+#include <cmath>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+// Tests linear algebra routines implemented in ReferenceUtil class.
+// TODO(b/23829238): Currently missing tests for the convolution routine.
+class ReferenceUtilTest : public ::testing::Test {
+ protected:
+ ReferenceUtilTest() {
+ matrix_ = MakeUnique<Array2D<float>>(rows_, cols_);
+ // [1.f 2.f 3.f]
+ // [4.f 5.f 6.f]
+ for (int64 i = 0; i < rows_; ++i) {
+ for (int64 j = 0; j < cols_; ++j) {
+ (*matrix_)(i, j) = i * cols_ + j + 1;
+ }
+ }
+ }
+
+ const int64 rows_ = 2;
+ const int64 cols_ = 3;
+ std::unique_ptr<Array2D<float>> matrix_;
+};
+
+TEST_F(ReferenceUtilTest, TransposeArray2D) {
+ auto result = ReferenceUtil::TransposeArray2D(*matrix_);
+ auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
+ *result_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, MatmulArray2D) {
+ Array2D<float> rhs({
+ {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f},
+ });
+ auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
+ auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
+ LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
+ *result_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
+ auto add = [](float lhs, float rhs) { return lhs + rhs; };
+ auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
+ auto result_literal = LiteralUtil::CreateR1<float>(*result);
+ LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *result_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
+ auto add = [](float lhs, float rhs) { return lhs + rhs; };
+ auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
+ auto result_literal = LiteralUtil::CreateR1<float>(*result);
+ LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *result_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, MapArray2D) {
+ auto identity = [](float value) { return log(exp(value)); };
+ auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
+ auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
+ LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
+ auto add_index = [](float value, int64 row, int64 col) {
+ return value + row + col;
+ };
+ auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
+ auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
+ LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
+ *result_literal, ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, MapArray4D) {
+ auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
+ input->FillWithMultiples(1.0f);
+ auto multiply_by_two = [](float value) { return 2 * value; };
+ auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
+ auto result_literal = LiteralUtil::CreateR4FromArray4D(*result);
+
+ Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
+ expected.FillWithMultiples(2.0f);
+ LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
+ auto input = MakeUnique<Array4D<float>>(/*planes=*/2, /*depth=*/3,
+ /*height=*/4, /*width=*/5);
+ input->FillWithMultiples(1.0f);
+ auto subtract_index = [](float value, int64 plane, int64 depth, int64 height,
+ int64 width) {
+ return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
+ };
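+  // FillWithMultiples(1.0f) fills the array with its elements' linear indices
+  // (0, 1, 2, ...), so subtracting the recomputed linear index of
+  // (plane, depth, height, width) should leave every element at zero, which is
+  // what the expectation below checks.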
+ auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
+ auto result_literal = LiteralUtil::CreateR4FromArray4D(*result);
+
+ Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
+ expected.Fill(0.0f);
+ LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ConvWithSamePadding) {
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> weights(1, 1, 2, 2);
+ // clang-format off
+ weights.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+ std::unique_ptr<Array4D<float>> actual =
+ ReferenceUtil::ConvArray4D(input, weights, {1, 1}, Padding::kSame);
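+  // With kSame padding the output keeps the 4x4 input size. Interior outputs
+  // are full 2x2 windowed dot products, e.g. output(0, 0) =
+  // 1*5 + 2*6 + 5*7 + 6*8 = 100, while edge outputs only see the in-bounds
+  // taps, e.g. output(3, 3) = 16*5 = 80.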
+ Array4D<float> expected(1, 1, 4, 4);
+ // clang-format off
+ expected.FillWithYX(Array2D<float>({
+ {100, 126, 152, 76},
+ {204, 230, 256, 124},
+ {308, 334, 360, 172},
+ {149, 160, 171, 80},
+ }));
+ // clang-format on
+
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
+
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ConvWithValidPadding) {
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> weights(1, 1, 2, 2);
+ // clang-format off
+ weights.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+ std::unique_ptr<Array4D<float>> actual =
+ ReferenceUtil::ConvArray4D(input, weights, {1, 1}, Padding::kValid);
+ Array4D<float> expected(1, 1, 3, 3);
+ // clang-format off
+ expected.FillWithYX(Array2D<float>({
+ {1*5+2*6+5*7+6*8, 126, 152},
+ {204, 230, 256},
+ {308, 334, 11*5+12*6+15*7+16*8},
+ }));
+ // clang-format on
+
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
+
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) {
+ // clang-format off
+ // Input dimensions: [feature=2, height=3, batch=1, width=4]
+ Array4D<float> input({
+ {{{1, 2, 3, 4}},
+ {{5, 6, 7, 8}},
+ {{9, 10, 11, 12}}},
+ {{{13, 14, 15, 16}},
+ {{17, 18, 19, 20}},
+ {{21, 22, 23, 24}}}
+ });
+ // Weight dimensions:
+ // [kernel_output_feature=1, height=3, kernel_input_feature=2, width=3]
+ Array4D<float> weight({{
+ {{1, 2, 3},
+ {4, 5, 6}},
+ {{7, 8, 9},
+ {10, 11, 12}},
+ {{13, 14, 15},
+ {16, 17, 18}}
+ }});
+ // clang-format on
+
+ // Set the convolution dimension numbers.
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_batch_dimension(2);
+ dimension_numbers.set_feature_dimension(0);
+ dimension_numbers.add_spatial_dimensions(1);
+ dimension_numbers.add_spatial_dimensions(3);
+ dimension_numbers.set_kernel_output_feature_dimension(0);
+ dimension_numbers.set_kernel_input_feature_dimension(2);
+ dimension_numbers.add_kernel_spatial_dimensions(1);
+ dimension_numbers.add_kernel_spatial_dimensions(3);
+
+ std::unique_ptr<Array4D<float>> actual =
+ ReferenceUtil::ConvArray4DGeneralDimensions(
+ input, weight, {1, 1}, Padding::kSame, dimension_numbers);
+ // clang-format off
+ // Result dimensions: [feature=1, height=3, batch=1, width=4]
+ Array4D<float> expected({{
+ {{1110, 1688, 1838, 1226}},
+ {{1683, 2514, 2685, 1761}},
+ {{878, 1280, 1358, 866}}
+ }});
+ // clang-format on
+
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
+
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
+ // clang-format off
+ // Input dimensions: [feature=2, height=3, batch=1, width=4]
+ Array4D<float> input({
+ {{{1, 2, 3, 4}},
+ {{5, 6, 7, 8}},
+ {{9, 10, 11, 12}}},
+ {{{13, 14, 15, 16}},
+ {{17, 18, 19, 20}},
+ {{21, 22, 23, 24}}}
+ });
+ // Weight dimensions:
+ // [kernel_output_feature=1, width=3, kernel_input_feature=2, height=3]
+ Array4D<float> weight({{
+ {{1, 7, 13},
+ {4, 10, 16}},
+ {{2, 8, 14},
+ {5, 11, 17}},
+ {{3, 9, 15},
+ {6, 12, 18}}
+ }});
+ // clang-format on
+
+ // Set the convolution dimension numbers.
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_batch_dimension(2);
+ dimension_numbers.set_feature_dimension(0);
+ dimension_numbers.add_spatial_dimensions(1);
+ dimension_numbers.add_spatial_dimensions(3);
+
+ dimension_numbers.set_kernel_output_feature_dimension(0);
+ dimension_numbers.set_kernel_input_feature_dimension(2);
+ dimension_numbers.add_kernel_spatial_dimensions(3);
+ dimension_numbers.add_kernel_spatial_dimensions(1);
+
+ std::unique_ptr<Array4D<float>> actual =
+ ReferenceUtil::ConvArray4DGeneralDimensions(
+ input, weight, {1, 1}, Padding::kValid, dimension_numbers);
+ // clang-format off
+ // Result dimensions: [feature=1, height=1, batch=1, width=2]
+ Array4D<float> expected({{{{2514, 2685}}}});
+ // clang-format on
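+  // This is the fully-in-bounds interior of the kSame result in the test above
+  // (the 2514, 2685 entries of its middle row): the two tests encode the same
+  // logical kernel, just with the kernel's height/width axes (and the matching
+  // kernel_spatial_dimensions entries) swapped.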
+
+ auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual);
+
+ LiteralTestUtil::ExpectR4NearArray4D<float>(expected, *actual_literal,
+ ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
new file mode 100644
index 0000000000..fb36443801
--- /dev/null
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -0,0 +1,1216 @@
+# Description:
+# XLA service implementation.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [":friends"])
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library")
+
+xla_proto_library(
+ name = "session_proto",
+ srcs = ["session.proto"],
+ visibility = ["//visibility:public"],
+ deps = ["//tensorflow/compiler/xla:xla_data_proto"],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "shape_inference",
+ srcs = ["shape_inference.cc"],
+ hdrs = ["shape_inference.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "shape_inference_test",
+ srcs = ["shape_inference_test.cc"],
+ deps = [
+ ":shape_inference",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "hlo_opcode_test",
+ srcs = ["hlo_opcode_test.cc"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo",
+ srcs = [
+ "dfs_hlo_visitor.cc",
+ "hlo_computation.cc",
+ "hlo_instruction.cc",
+ "hlo_module.cc",
+ "hlo_opcode.cc",
+ ],
+ hdrs = [
+ "dfs_hlo_visitor.h",
+ "dfs_hlo_visitor_with_default.h",
+ "hlo_computation.h",
+ "hlo_instruction.h",
+ "hlo_module.h",
+ "hlo_opcode.h",
+ ],
+ deps = [
+ ":name_uniquer",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:protobuf_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "versioned_computation_handle",
+ hdrs = ["versioned_computation_handle.h"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "hlo_instruction_test",
+ srcs = ["hlo_instruction_test.cc"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "user_computation",
+ srcs = ["user_computation.cc"],
+ hdrs = ["user_computation.h"],
+ deps = [
+ ":hlo",
+ ":session_proto",
+ ":shape_inference",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "platform_util",
+ srcs = ["platform_util.cc"],
+ hdrs = ["platform_util.h"],
+ deps = [
+ ":compiler",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "backend",
+ srcs = ["backend.cc"],
+ hdrs = ["backend.h"],
+ deps = [
+ ":compiler",
+ ":device_memory_allocator",
+ ":platform_util",
+ ":transfer_manager",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:backend_flags",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
+ name = "service",
+ srcs = ["service.cc"],
+ hdrs = ["service.h"],
+ deps = [
+ ":allocation_tracker",
+ ":backend",
+ ":channel_tracker",
+ ":compilation_cache",
+ ":compiler",
+ ":computation_layout",
+ ":computation_tracker",
+ ":cpu_transfer_manager",
+ ":device_memory_allocator",
+ ":executable",
+ ":execution_tracker",
+ ":hlo",
+ ":hlo_cost_analysis",
+ ":hlo_execution_profile",
+ ":hlo_graph_dumper",
+ ":hlo_module_config",
+ ":platform_util",
+ ":session_proto",
+ ":transfer_manager",
+ ":user_computation",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:service_interface",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/legacy_flags:service_flags",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:regexp_internal",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "local_service",
+ srcs = ["local_service.cc"],
+ hdrs = ["local_service.h"],
+ deps = [
+ ":backend",
+ ":compiler",
+ ":computation_layout",
+ ":computation_tracker",
+ ":device_memory_allocator",
+ ":executable",
+ ":hlo",
+ ":hlo_execution_profile",
+ ":hlo_module_config",
+ ":platform_util",
+ ":service",
+ ":shaped_buffer",
+ ":user_computation",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
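+# cpu_plugin and gpu_plugin are aggregation targets with no sources of their
+# own: depending on one of them pulls in the XLA service plus that backend's
+# compiler and transfer manager (the alwayslink deps carry the per-platform
+# registrations).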
+cc_library(
+ name = "cpu_plugin",
+ deps = [
+ ":cpu_transfer_manager",
+ ":service",
+ "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "gpu_plugin",
+ deps = [
+ ":generic_transfer_manager",
+ ":service",
+ "//tensorflow/compiler/xla/service/gpu:gpu_compiler",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
+ ],
+)
+
+cc_library(
+ name = "shaped_buffer",
+ srcs = ["shaped_buffer.cc"],
+ hdrs = ["shaped_buffer.h"],
+ deps = [
+ ":device_memory_allocator",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "executable",
+ srcs = ["executable.cc"],
+ hdrs = ["executable.h"],
+ deps = [
+ ":computation_layout",
+ ":device_memory_allocator",
+ ":hlo",
+ ":hlo_execution_profile",
+ ":hlo_module_config",
+ ":session_proto",
+ ":shaped_buffer",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:service_flags",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "compiler",
+ srcs = ["compiler.cc"],
+ hdrs = ["compiler.h"],
+ deps = [
+ ":executable",
+ ":hlo",
+ ":hlo_module_config",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "transfer_manager",
+ srcs = ["transfer_manager.cc"],
+ hdrs = ["transfer_manager.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "allocation_tracker",
+ srcs = ["allocation_tracker.cc"],
+ hdrs = ["allocation_tracker.h"],
+ deps = [
+ ":backend",
+ ":device_memory_allocator",
+ ":transfer_manager",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "execution_tracker",
+ srcs = ["execution_tracker.cc"],
+ hdrs = ["execution_tracker.h"],
+ deps = [
+ ":backend",
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "computation_tracker",
+ srcs = ["computation_tracker.cc"],
+ hdrs = ["computation_tracker.h"],
+ deps = [
+ ":hlo",
+ ":session_proto",
+ ":user_computation",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "channel_tracker",
+ srcs = ["channel_tracker.cc"],
+ hdrs = ["channel_tracker.h"],
+ deps = [
+ ":hlo",
+ ":session_proto",
+ ":user_computation",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "name_uniquer",
+ srcs = ["name_uniquer.cc"],
+ hdrs = ["name_uniquer.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "buffer_liveness",
+ srcs = [
+ "buffer_liveness.cc",
+ ],
+ hdrs = [
+ "buffer_liveness.h",
+ ],
+ deps = [
+ ":hlo",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "buffer_liveness_test",
+ srcs = ["buffer_liveness_test.cc"],
+ deps = [
+ ":buffer_liveness",
+ ":cpu_plugin",
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "buffer_assignment",
+ srcs = [
+ "buffer_assignment.cc",
+ ],
+ hdrs = [
+ "buffer_assignment.h",
+ ],
+ deps = [
+ ":buffer_liveness",
+ ":hlo",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "buffer_assignment_test",
+ srcs = ["buffer_assignment_test.cc"],
+ deps = [
+ ":buffer_assignment",
+ ":computation_tracker",
+ ":cpu_plugin",
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_query",
+ srcs = ["hlo_query.cc"],
+ hdrs = ["hlo_query.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ ],
+)
+
+cc_library(
+ name = "instruction_fusion",
+ srcs = ["instruction_fusion.cc"],
+ hdrs = ["instruction_fusion.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "instruction_fusion_test",
+ srcs = ["instruction_fusion_test.cc"],
+ deps = [
+ ":instruction_fusion",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "algebraic_simplifier",
+ srcs = ["algebraic_simplifier.cc"],
+ hdrs = ["algebraic_simplifier.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "algebraic_simplifier_test",
+ srcs = ["algebraic_simplifier_test.cc"],
+ deps = [
+ ":algebraic_simplifier",
+ ":cpu_plugin",
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "reshape_mover",
+ srcs = ["reshape_mover.cc"],
+ hdrs = ["reshape_mover.h"],
+ deps = [
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ ],
+)
+
+cc_test(
+ name = "reshape_mover_test",
+ srcs = ["reshape_mover_test.cc"],
+ deps = [
+ ":hlo",
+ ":reshape_mover",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "inliner",
+ srcs = ["inliner.cc"],
+ hdrs = ["inliner.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "inliner_test",
+ srcs = ["inliner_test.cc"],
+ deps = [
+ ":cpu_plugin",
+ ":hlo",
+ ":inliner",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "generic_transfer_manager",
+ srcs = ["generic_transfer_manager.cc"],
+ hdrs = ["generic_transfer_manager.h"],
+ deps = [
+ ":transfer_manager",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+ alwayslink = True, # Contains per-platform transfer manager registration
+)
+
+cc_library(
+ name = "cpu_transfer_manager",
+ srcs = ["cpu_transfer_manager.cc"],
+ hdrs = ["cpu_transfer_manager.h"],
+ deps = [
+ ":generic_transfer_manager",
+ ":transfer_manager",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/cpu:cpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+ alwayslink = True, # Contains per-platform transfer manager registration
+)
+
+cc_test(
+ name = "transfer_manager_test",
+ srcs = ["transfer_manager_test.cc"],
+ deps = [
+ ":cpu_transfer_manager",
+ ":generic_transfer_manager",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_cost_analysis",
+ srcs = [
+ "hlo_cost_analysis.cc",
+ ],
+ hdrs = [
+ "hlo_cost_analysis.h",
+ ],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "hlo_cost_analysis_test",
+ srcs = ["hlo_cost_analysis_test.cc"],
+ deps = [
+ ":computation_tracker",
+ ":hlo",
+ ":hlo_cost_analysis",
+ ":local_service",
+ ":service",
+ ":user_computation",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_execution_profile",
+ srcs = ["hlo_execution_profile.cc"],
+ hdrs = ["hlo_execution_profile.h"],
+ deps = [
+ ":hlo",
+ ":hlo_cost_analysis",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_test(
+ name = "hlo_computation_test",
+ srcs = ["hlo_computation_test.cc"],
+ deps = [
+ ":cpu_plugin",
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_binary(
+ name = "graphviz_example",
+ srcs = ["graphviz_example.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_graph_dumper",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "hlo_module_test",
+ srcs = ["hlo_module_test.cc"],
+ deps = [
+ ":cpu_plugin",
+ ":hlo",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "logical_buffer",
+ srcs = [
+ "logical_buffer.cc",
+ ],
+ hdrs = [
+ "logical_buffer.h",
+ ],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "tuple_points_to_analysis",
+ srcs = [
+ "tuple_points_to_analysis.cc",
+ ],
+ hdrs = [
+ "tuple_points_to_analysis.h",
+ ],
+ deps = [
+ ":hlo",
+ ":logical_buffer",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "tuple_points_to_analysis_test",
+ srcs = ["tuple_points_to_analysis_test.cc"],
+ deps = [
+ ":hlo",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "compilation_cache",
+ srcs = [
+ "compilation_cache.cc",
+ ],
+ hdrs = [
+ "compilation_cache.h",
+ ],
+ deps = [
+ ":executable",
+ ":hlo_module_config",
+ ":versioned_computation_handle",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "layout_assignment",
+ srcs = [
+ "layout_assignment.cc",
+ ],
+ hdrs = [
+ "layout_assignment.h",
+ ],
+ deps = [
+ ":computation_layout",
+ ":hlo",
+ ":hlo_graph_dumper",
+ ":hlo_pass",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "copy_insertion",
+ srcs = ["copy_insertion.cc"],
+ hdrs = ["copy_insertion.h"],
+ deps = [
+ ":buffer_liveness",
+ ":hlo",
+ ":hlo_pass",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "copy_insertion_test",
+ srcs = ["copy_insertion_test.cc"],
+ deps = [
+ ":buffer_liveness",
+ ":copy_insertion",
+ ":cpu_plugin",
+ ":hlo",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_dce",
+ srcs = ["hlo_dce.cc"],
+ hdrs = ["hlo_dce.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "hlo_dce_test",
+ srcs = ["hlo_dce_test.cc"],
+ deps = [
+ ":cpu_plugin",
+ ":hlo",
+ ":hlo_dce",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_test(
+ name = "layout_assignment_test",
+ srcs = ["layout_assignment_test.cc"],
+ deps = [
+ ":algebraic_simplifier",
+ ":computation_layout",
+ ":cpu_plugin",
+ ":hlo",
+ ":layout_assignment",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_pass",
+ hdrs = ["hlo_pass.h"],
+ deps = [
+ ":hlo",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_pass_pipeline",
+ srcs = [
+ "hlo_pass_pipeline.cc",
+ ],
+ hdrs = [
+ "hlo_pass_pipeline.h",
+ ],
+ deps = [
+ ":compiler",
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_cse",
+ srcs = ["hlo_cse.cc"],
+ hdrs = ["hlo_cse.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ ],
+)
+
+cc_test(
+ name = "hlo_cse_test",
+ srcs = ["hlo_cse_test.cc"],
+ deps = [
+ ":cpu_plugin",
+ ":hlo",
+ ":hlo_cse",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "device_memory_allocator",
+ srcs = ["device_memory_allocator.cc"],
+ hdrs = ["device_memory_allocator.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "elemental_ir_emitter",
+ srcs = ["elemental_ir_emitter.cc"],
+ hdrs = ["elemental_ir_emitter.h"],
+ deps = [
+ ":hlo",
+ ":hlo_module_config",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "@llvm//:core",
+ "@llvm//:transform_utils",
+ ],
+)
+
+cc_library(
+ name = "hlo_module_config",
+ srcs = ["hlo_module_config.cc"],
+ hdrs = ["hlo_module_config.h"],
+ deps = [
+ ":computation_layout",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "computation_layout",
+ srcs = ["computation_layout.cc"],
+ hdrs = ["computation_layout.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "hlo_subcomputation_unification",
+ srcs = ["hlo_subcomputation_unification.cc"],
+ hdrs = ["hlo_subcomputation_unification.h"],
+ deps = [
+ ":hlo_pass",
+ ],
+)
+
+cc_test(
+ name = "hlo_subcomputation_unification_test",
+ srcs = ["hlo_subcomputation_unification_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_graph_dumper",
+ ":hlo_subcomputation_unification",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_graph_dumper",
+ srcs = [
+ "hlo_graph_dumper.cc",
+ ],
+ hdrs = ["hlo_graph_dumper.h"],
+ deps = [
+ ":hlo",
+ ":hlo_execution_profile",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "transpose_folding",
+ srcs = ["transpose_folding.cc"],
+ hdrs = ["transpose_folding.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "transpose_folding_test",
+ srcs = ["transpose_folding_test.cc"],
+ deps = [
+ ":hlo",
+ ":transpose_folding",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/gpu:ir_emission_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
new file mode 100644
index 0000000000..fe892e872f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -0,0 +1,938 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+
+#include <algorithm>
+#include <memory>
+#include <numeric>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Returns whether operand is a literal with the given value.
+bool IsLiteralWithValue(const HloInstruction* operand, int value) {
+ return operand->opcode() == HloOpcode::kConstant &&
+ LiteralUtil::IsAll(operand->literal(), value);
+}
+
+// Returns whether the given transpose produces a result which is bit-wise
+// identical to its operand and thus may be replaced with a bitcast.
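+// For example (with XLA's minor-to-major layouts), transposing f32[2,3] with
+// layout {0,1} into f32[3,2] with layout {1,0} only relabels the logical
+// indices; each element keeps its byte offset, so the transpose can be
+// replaced by a kBitcast.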
+bool TransposeIsBitcast(
+ const HloInstruction* transpose,
+ const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
+ CHECK_EQ(HloOpcode::kTranspose, transpose->opcode());
+ const HloInstruction* operand = transpose->operand(0);
+
+ // Can't insert bitcasts if the compiler used a memory layout which isn't
+ // compatible.
+ if (!valid_bitcast_callback(operand->shape(), transpose->shape())) {
+ return false;
+ }
+
+ return ShapeUtil::TransposeIsBitcast(operand->shape(), transpose->shape(),
+ transpose->dimensions());
+}
+
+// Returns true if the given reshape produces a result which is bit-wise
+// identical to its operand and thus may be replaced with a bitcast.
+//
+// This function is conservative -- even if this function returns false, the
+// reshape may still be a bitcast. For example, a reshape from [28x28] to [784].
+bool ReshapeIsBitcast(
+ const HloInstruction* reshape,
+ const AlgebraicSimplifier::ValidBitcastCallback& valid_bitcast_callback) {
+ CHECK_EQ(HloOpcode::kReshape, reshape->opcode());
+
+ const HloInstruction* operand = reshape->operand(0);
+ // Can't insert bitcasts if the compiler used a memory layout which isn't
+ // compatible.
+ if (!valid_bitcast_callback(operand->shape(), reshape->shape())) {
+ return false;
+ }
+
+ return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape());
+}
+} // namespace
+
+// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
+// algebraic expressions to simplified forms. Note: This only supports
+// simplifications that simply look at the operands of an instruction. For the
+// more general case a worklist based approach would be needed.
+class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
+ public:
+ // Default visitor action is to do nothing and return OK.
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
+ return Status::OK();
+ }
+
+ Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+
+ Status HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) override;
+
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+
+ Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+
+ Status HandleLog(HloInstruction* log, HloInstruction* operand) override;
+
+ Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandlePad(HloInstruction* pad) override;
+
+ Status HandlePower(HloInstruction* power, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleReshape(HloInstruction* reshape) override;
+
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override;
+
+ Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+
+ Status HandleTranspose(HloInstruction* transpose) override;
+
+ Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+
+ // Returns whether algebraic simplification has occurred.
+  bool changed() const { return changed_; }
+
+ // Runs the visitor on a computation.
+ static bool Run(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback);
+
+ private:
+ explicit AlgebraicSimplifierVisitor(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback)
+ : computation_(computation),
+ is_layout_sensitive_(is_layout_sensitive),
+ valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+
+ // Convenience method for replacing an instruction with a bitcast.
+ void ReplaceWithBitcast(HloInstruction* instruction);
+
+ // Replace old instruction with new instruction if old and new instructions
+ // have the same shape. Updates uses and root instruction. Returns whether a
+ // replacement was made.
+ bool ReplaceInstructionIfSameShape(HloInstruction* old_instruction,
+ HloInstruction* new_instruction);
+
+  // Returns whether the shapes of the outputs of the given instructions are the
+  // same for the purposes of simplification. If is_layout_sensitive_ is true,
+  // then this tests shape equality including layout (ShapeUtil::Equal). If
+  // is_layout_sensitive_ is false, then this tests shape compatibility
+  // (ShapeUtil::Compatible).
+ bool SameShape(const HloInstruction* lhs, const HloInstruction* rhs) const;
+
+ // Returns whether it was possible to transform `root` to a clamp instruction.
+  // Here min is a minimum instruction, max is a maximum instruction, min_operand
+  // is an operand of min, and max_operand is an operand of max.
+ // Precondition: root is either a minimum or a maximum.
+ bool TransformToClampIfSameShape(HloInstruction* root, HloInstruction* min,
+ HloInstruction* min_operand,
+ HloInstruction* operand, HloInstruction* max,
+ HloInstruction* max_operand);
+
+ // Current HloComputation instance the AlgebraicSimplifierVisitor is
+ // traversing.
+ HloComputation* computation_;
+
+ // Whether algebraic simplification has occurred.
+ bool changed_ = false;
+
+ // Whether layout is considered during transformation.
+ bool is_layout_sensitive_;
+
+ // Callback used to determine if a bitcast is valid.
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback_;
+};
+
+bool AlgebraicSimplifierVisitor::Run(
+ HloComputation* computation, bool is_layout_sensitive,
+ AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback) {
+ AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive,
+ std::move(valid_bitcast_callback));
+ TF_CHECK_OK(computation->root_instruction()->Accept(&visitor));
+ return visitor.changed_;
+}
+
+bool AlgebraicSimplifierVisitor::SameShape(const HloInstruction* lhs,
+ const HloInstruction* rhs) const {
+ if (is_layout_sensitive_) {
+ return ShapeUtil::Equal(lhs->shape(), rhs->shape());
+ } else {
+ return ShapeUtil::Compatible(lhs->shape(), rhs->shape());
+ }
+}
+
+void AlgebraicSimplifierVisitor::ReplaceWithBitcast(
+ HloInstruction* instruction) {
+ CHECK_EQ(1, instruction->operand_count());
+ CHECK_EQ(ShapeUtil::ElementsIn(instruction->shape()),
+ ShapeUtil::ElementsIn(instruction->operand(0)->shape()));
+ CHECK_EQ(ShapeUtil::ByteSizeOf(instruction->shape()),
+ ShapeUtil::ByteSizeOf(instruction->operand(0)->shape()));
+
+ auto bitcast = computation_->AddInstruction(
+ HloInstruction::CreateUnary(instruction->shape(), HloOpcode::kBitcast,
+ instruction->mutable_operand(0)));
+ computation_->ReplaceInstruction(instruction, bitcast);
+ changed_ = true;
+}
+
+bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape(
+ HloInstruction* old_instruction, HloInstruction* new_instruction) {
+ if (!SameShape(old_instruction, new_instruction)) {
+ return false;
+ }
+ computation_->ReplaceInstruction(old_instruction, new_instruction);
+ changed_ = true;
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A + 0 => A
+ VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString();
+ if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) {
+ return Status::OK();
+ }
+ // 0 + A => A
+ VLOG(10) << "trying transform [0 + A => A]: " << add->ToString();
+ if (IsLiteralWithValue(lhs, 0) && ReplaceInstructionIfSameShape(add, rhs)) {
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy,
+ HloInstruction* operand) {
+  // All copies can be eliminated (assuming layout constraints are satisfied).
+ ReplaceInstructionIfSameShape(copy, operand);
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A - 0 => A
+ VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString();
+ if (IsLiteralWithValue(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) {
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A/1 => A
+ VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString();
+ if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) {
+ return Status::OK();
+ }
+
+ // exp(A)/exp(B) => exp(A-B)
+ if (lhs->opcode() == HloOpcode::kExp && rhs->opcode() == HloOpcode::kExp) {
+ VLOG(10) << "transform [exp(A)/exp(B) => exp(A-B)]: " << divide->ToString();
+ HloInstruction* subtract =
+ computation_->AddInstruction(HloInstruction::CreateBinary(
+ divide->shape(), HloOpcode::kSubtract, lhs->mutable_operand(0),
+ rhs->mutable_operand(0)));
+ computation_->ReplaceWithNewInstruction(
+ divide, HloInstruction::CreateUnary(divide->shape(), HloOpcode::kExp,
+ subtract));
+ changed_ = true;
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // A*1 => A
+ VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString();
+ if (IsLiteralWithValue(rhs, 1) &&
+ ReplaceInstructionIfSameShape(multiply, lhs)) {
+ return Status::OK();
+ }
+ // 1*A => A
+ VLOG(10) << "trying transform [1*A => A]: " << multiply->ToString();
+ if (IsLiteralWithValue(lhs, 1) &&
+ ReplaceInstructionIfSameShape(multiply, rhs)) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log,
+ HloInstruction* operand) {
+ // ln(exp(A)) => A
+ VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString();
+ if (operand->opcode() == HloOpcode::kExp &&
+ ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleGetTupleElement(
+ HloInstruction* get_tuple_element, HloInstruction* operand) {
+ if (operand->opcode() == HloOpcode::kTuple) {
+ // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i
+ VLOG(10) << "trying transform "
+ << "[get_tuple_element(make_tuple({...,A_i,...}), i)] => A_i: "
+ << get_tuple_element->ToString();
+ if (ReplaceInstructionIfSameShape(
+ get_tuple_element,
+ operand->mutable_operand(get_tuple_element->tuple_index()))) {
+ return Status::OK();
+ }
+ }
+ return Status::OK();
+}
+
+namespace {
+
+// Returns whether the given reshape instruction leaves the dimensions at the
+// given input indices unmodified, and, if so, their output indices.
+//
+// Example:
+// input_dim_indices = {2, 3}
+// input shape = T[a, b, x, y, cd]
+// output shape = T[ab, x, 1, y, c, d]
+// return value = {1, 3}
+//
+// Precondition: input_dim_indices is sorted.
+std::pair<bool, std::vector<int64>> ReshapeLeavesDimensionsUnmodified(
+ const HloInstruction* hlo,
+ tensorflow::gtl::ArraySlice<int64> input_dim_indices) {
+ CHECK_EQ(HloOpcode::kReshape, hlo->opcode());
+ CHECK(std::is_sorted(input_dim_indices.begin(), input_dim_indices.end()));
+
+ std::vector<int64> output_dim_indices;
+ std::vector<std::pair<int64, int64>> unmodified_dims =
+ ShapeUtil::DimensionsUnmodifiedByReshape(hlo->operand(0)->shape(),
+ hlo->shape());
+ size_t i = 0; // index to unmodified_dims
+ for (int64 input_dim_index : input_dim_indices) {
+ // Search unmodified_dims for input_dim_index. We can search from the last
+ // matching position because input_dim_indices is guaranteed to be sorted.
+ while (i < unmodified_dims.size() &&
+ unmodified_dims[i].first < input_dim_index) {
+ ++i;
+ }
+ if (i >= unmodified_dims.size() ||
+ unmodified_dims[i].first != input_dim_index) {
+ return std::make_pair(false, std::vector<int64>());
+ }
+ output_dim_indices.push_back(unmodified_dims[i].second);
+ }
+ return std::make_pair(true, output_dim_indices);
+}
+
+// Returns true if the output of "instruction" is a permutation of the elements
+// of "operand". Precondition: "operand" is an operand of "instruction".
+bool OutputIsPermutationOfOperandElements(HloInstruction* instruction,
+ HloInstruction* operand) {
+ DCHECK(!instruction->OperandIndices(operand).empty());
+ switch (instruction->opcode()) {
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSort:
+ case HloOpcode::kTranspose:
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Returns true if the output of "instruction" is a subset of the elements of
+// "operand". Precondition: "operand" is an operand of "instruction".
+bool OutputIsSubsetOfOperandElements(HloInstruction* instruction,
+ HloInstruction* operand) {
+ std::vector<int64> operand_indices = instruction->OperandIndices(operand);
+ CHECK(!operand_indices.empty());
+ if (operand_indices.size() != 1) {
+ return false;
+ }
+ int64 operand_index = operand_indices[0];
+ switch (instruction->opcode()) {
+ case HloOpcode::kSlice:
+ CHECK_EQ(0, operand_index);
+ return true;
+ case HloOpcode::kDynamicSlice:
+ return operand_index == 0;
+ default:
+ return false;
+ }
+}
+
+} // namespace
+
+Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) {
+ auto operand = broadcast->mutable_operand(0);
+  // A degenerate broadcast, i.e. one that does not change the number of
+  // elements, can be replaced by a reshape.
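+  // For example, broadcasting f32[2,3] to f32[1,2,3] with dimensions {1,2}
+  // copies every element exactly once, so it is equivalent to a reshape to
+  // [1,2,3].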
+ if (std::is_sorted(broadcast->dimensions().begin(),
+ broadcast->dimensions().end()) &&
+ ShapeUtil::ElementsIn(broadcast->shape()) ==
+ ShapeUtil::ElementsIn(operand->shape())) {
+ VLOG(10) << "transform broadcast(X) -> reshape(X) where "
+ "n(broadcast(X)) == n(X)";
+ computation_->ReplaceWithNewInstruction(
+ broadcast, HloInstruction::CreateReshape(broadcast->shape(), operand));
+ changed_ = true;
+ return Status::OK();
+ }
+
+  // A broadcast whose operand is a reshape that merely inserts 1-sized
+  // dimensions can skip the reshape and broadcast the reshape's operand
+  // directly.
+ {
+ bool merely_inserts_or_deletes_1_sized_dimensions;
+ std::vector<int64> inserted_indices, deleted_indices;
+ std::tie(merely_inserts_or_deletes_1_sized_dimensions, deleted_indices,
+ inserted_indices) =
+ operand->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+ if (merely_inserts_or_deletes_1_sized_dimensions &&
+ deleted_indices.empty()) {
+ std::reverse(inserted_indices.begin(), inserted_indices.end());
+ auto dims = broadcast->dimensions();
+ for (auto inserted_index : inserted_indices) {
+ dims.erase(dims.begin() + inserted_index);
+ }
+ computation_->ReplaceWithNewInstruction(
+ broadcast,
+ HloInstruction::CreateBroadcast(broadcast->shape(),
+ operand->mutable_operand(0), dims));
+ changed_ = true;
+ return Status::OK();
+ }
+ }
+
+ // A scalar broadcast feeding an instruction which only permutes (reshape,
+ // transpose, sort, reverse) or selects a subset of operand elements (slice,
+ // dynamic slice) can be replaced with a broadcast directly to the output
+ // shape of the instruction.
+ if (ShapeUtil::IsScalar(operand->shape())) {
+ for (HloInstruction* user : broadcast->users()) {
+ if (OutputIsPermutationOfOperandElements(user, broadcast) ||
+ OutputIsSubsetOfOperandElements(user, broadcast)) {
+ HloInstruction* new_broadcast = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(user->shape(), operand, {}));
+ // Use ReplaceUsesOfInstruction instead of ReplaceWithNewInstruction
+ // because we are replacing an instruction other than the visited
+ // instruction.
+ computation_->ReplaceUsesOfInstruction(user, new_broadcast);
+ changed_ = true;
+ return Status::OK();
+ }
+ }
+ }
+ return Status::OK();
+}
+
+template <PrimitiveType primitive_src_type, PrimitiveType primitive_dest_type>
+static std::unique_ptr<HloInstruction> ConvertIfTypesMatch(
+ const Literal& src_literal) {
+ CHECK_EQ(primitive_src_type, src_literal.shape().element_type());
+
+ return HloInstruction::CreateConstant(
+ LiteralUtil::Convert<typename primitive_util::PrimitiveTypeToNative<
+ primitive_src_type>::type,
+ typename primitive_util::PrimitiveTypeToNative<
+ primitive_dest_type>::type>(src_literal));
+}
+
+template <PrimitiveType primitive_src_type>
+static std::unique_ptr<HloInstruction> ConvertIfDestTypeMatches(
+ const Literal& src_literal, PrimitiveType primitive_dest_type) {
+ switch (primitive_dest_type) {
+#define CONVERT_IF_TYPES_MATCH(type) \
+ case (type): \
+ return ConvertIfTypesMatch<primitive_src_type, (type)>(src_literal);
+ CONVERT_IF_TYPES_MATCH(PRED)
+ CONVERT_IF_TYPES_MATCH(S8)
+ CONVERT_IF_TYPES_MATCH(S32)
+ CONVERT_IF_TYPES_MATCH(S64)
+ CONVERT_IF_TYPES_MATCH(U8)
+ CONVERT_IF_TYPES_MATCH(U32)
+ CONVERT_IF_TYPES_MATCH(U64)
+ CONVERT_IF_TYPES_MATCH(F32)
+ CONVERT_IF_TYPES_MATCH(F64)
+#undef CONVERT_IF_TYPES_MATCH
+ // Other types are not yet supported.
+ default:
+ LOG(FATAL) << "Unimplemented: ConvertIfDestTypeMatches for type "
+ << PrimitiveType_Name(src_literal.shape().element_type());
+ }
+}
+
+static std::unique_ptr<HloInstruction> ConvertIfSrcTypeMatches(
+ const Literal& src_literal, PrimitiveType primitive_dest_type) {
+ switch (src_literal.shape().element_type()) {
+#define CONVERT_IF_DEST_TYPE_MATCHES(type) \
+ case (type): \
+ return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type);
+ CONVERT_IF_DEST_TYPE_MATCHES(PRED)
+ CONVERT_IF_DEST_TYPE_MATCHES(S8)
+ CONVERT_IF_DEST_TYPE_MATCHES(S32)
+ CONVERT_IF_DEST_TYPE_MATCHES(S64)
+ CONVERT_IF_DEST_TYPE_MATCHES(U8)
+ CONVERT_IF_DEST_TYPE_MATCHES(U32)
+ CONVERT_IF_DEST_TYPE_MATCHES(U64)
+ CONVERT_IF_DEST_TYPE_MATCHES(F32)
+ CONVERT_IF_DEST_TYPE_MATCHES(F64)
+#undef CONVERT_IF_DEST_TYPE_MATCHES
+ // Other types are not yet supported.
+ default:
+ LOG(FATAL) << "Unimplemented: ConvertIfSrcTypeMatches for type "
+ << PrimitiveType_Name(src_literal.shape().element_type());
+ }
+}
+
+// A conversion to the same element type as the operand is a nop and can be
+// removed. A conversion of a constant can be simplified by making a new
+// constant.
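+// For instance, convert(f32[4] x) to F32 is replaced by x, and a convert of an
+// f32 constant to S64 is folded into a new s64 constant.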
+Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) {
+ PrimitiveType src_type = operand->shape().element_type();
+ PrimitiveType dest_type = convert->shape().element_type();
+ if (src_type == dest_type) {
+ computation_->ReplaceInstruction(convert, operand);
+ changed_ = true;
+ return Status::OK();
+ }
+ if (operand->opcode() == HloOpcode::kConstant) {
+ const Literal& src_literal = operand->literal();
+ std::unique_ptr<HloInstruction> new_constant =
+ ConvertIfSrcTypeMatches(src_literal, dest_type);
+ computation_->ReplaceWithNewInstruction(convert, std::move(new_constant));
+ changed_ = true;
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) {
+ // The pad instruction does nothing if the output shape is the same as the
+ // input shape, i.e., all paddings are zero.
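+ // e.g. pad(f32[2,3] x, v) with zero low/high/interior padding in every
+ // dimension has shape f32[2,3] and is replaced by x.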
+ ReplaceInstructionIfSameShape(pad, pad->mutable_operand(0));
+ return Status::OK();
+}
+
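+// Simplifies powers whose exponent is a constant: pow(A, 0) => 1 (broadcast to
+// the power's shape when it is not a scalar), pow(A, 1) => A,
+// pow(A, 2) => A*A, and pow(A, -1) => 1/A.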
+Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString();
+ if (IsLiteralWithValue(rhs, 0)) {
+ auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
+ LiteralUtil::One(power->shape().element_type())));
+ std::unique_ptr<HloInstruction> ones;
+ if (ShapeUtil::IsScalar(power->shape())) {
+ ones = std::move(one);
+ } else {
+ ones = HloInstruction::CreateBroadcast(
+ power->shape(), computation_->AddInstruction(std::move(one)), {});
+ }
+ computation_->ReplaceWithNewInstruction(power, std::move(ones));
+ changed_ = true;
+ return Status::OK();
+ }
+
+ VLOG(10) << "trying transform [pow(A, 1) => A]: " << power->ToString();
+ if (IsLiteralWithValue(rhs, 1) && ReplaceInstructionIfSameShape(power, lhs)) {
+ return Status::OK();
+ }
+
+ VLOG(10) << "trying transform [pow(A, 2) => A*A]: " << power->ToString();
+ if (IsLiteralWithValue(rhs, 2)) {
+ computation_->ReplaceWithNewInstruction(
+ power, HloInstruction::CreateBinary(power->shape(),
+ HloOpcode::kMultiply, lhs, lhs));
+ changed_ = true;
+ return Status::OK();
+ }
+
+ VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString();
+ if (IsLiteralWithValue(rhs, -1)) {
+ auto* one = computation_->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CloneToUnique(
+ LiteralUtil::One(rhs->shape().element_type()))));
+ computation_->ReplaceWithNewInstruction(
+ power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide,
+ one, lhs));
+ changed_ = true;
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) {
+ auto operand = reshape->mutable_operand(0);
+
+ // Delete no-op reshapes, i.e. where shape = operand shape.
+ if (SameShape(reshape, operand)) {
+ VLOG(10) << "deleting no-op reshape";
+ computation_->ReplaceInstruction(reshape, operand);
+ changed_ = true;
+ return Status::OK();
+ }
+
+ // Merge reshapes.
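+ // e.g. reshape(reshape(f32[2,3] x -> f32[6]) -> f32[3,2]) becomes
+ // reshape(x -> f32[3,2]).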
+ if (HloOpcode::kReshape == operand->opcode()) {
+ computation_->ReplaceWithNewInstruction(
+ reshape, HloInstruction::CreateReshape(reshape->shape(),
+ operand->mutable_operand(0)));
+ changed_ = true;
+ return Status::OK();
+ }
+
+ if (HloOpcode::kBroadcast == reshape->operand(0)->opcode()) {
+ auto opt_dims = ReshapeLeavesDimensionsUnmodified(
+ reshape, reshape->operand(0)->dimensions());
+ if (opt_dims.first) {
+ computation_->ReplaceWithNewInstruction(
+ reshape,
+ HloInstruction::CreateBroadcast(
+ reshape->shape(), reshape->mutable_operand(0)->mutable_operand(0),
+ opt_dims.second));
+ changed_ = true;
+ return Status::OK();
+ }
+ }
+
+ // Make this a bitcast if possible.
+ if (is_layout_sensitive_ &&
+ ReshapeIsBitcast(reshape, valid_bitcast_callback_)) {
+ ReplaceWithBitcast(reshape);
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
+ HloInstruction* operand) {
+ // Delete no-op slices, i.e. where shape = operand shape.
+ if (ReplaceInstructionIfSameShape(slice, operand)) {
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
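+// A reduce whose output has as many elements as its input does not combine any
+// elements; it is rewritten as a map of the reduction function over a reshape
+// of the operand together with the init value.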
+Status AlgebraicSimplifierVisitor::HandleReduce(
+ HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
+ if (ShapeUtil::ElementsIn(reduce->shape()) ==
+ ShapeUtil::ElementsIn(arg->shape())) {
+ auto reshape = computation_->AddInstruction(
+ HloInstruction::CreateReshape(reduce->shape(), arg));
+ computation_->ReplaceWithNewInstruction(
+ reduce, HloInstruction::CreateMap(reduce->shape(),
+ {reshape, init_value}, function));
+ changed_ = true;
+ return Status::OK();
+ }
+ return Status::OK();
+}
+
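+// Removes no-op transposes (identity permutation), composes consecutive
+// transposes into a single transpose, and, when layout sensitive, replaces a
+// transpose that leaves the physical layout of the data unchanged with a
+// bitcast.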
+Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) {
+ auto operand = transpose->mutable_operand(0);
+
+ if (std::is_sorted(transpose->dimensions().begin(),
+ transpose->dimensions().end())) {
+ VLOG(10) << "deleting no-op transpose";
+ computation_->ReplaceInstruction(transpose, operand);
+ changed_ = true;
+ return Status::OK();
+ }
+
+ if (HloOpcode::kTranspose == operand->opcode()) {
+ computation_->ReplaceWithNewInstruction(
+ transpose, HloInstruction::CreateTranspose(
+ transpose->shape(), operand->mutable_operand(0),
+ ComposePermutations(operand->dimensions(),
+ transpose->dimensions())));
+ changed_ = true;
+ return Status::OK();
+ }
+
+ if (is_layout_sensitive_ &&
+ TransposeIsBitcast(transpose, valid_bitcast_callback_)) {
+ ReplaceWithBitcast(transpose);
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleConvolution(
+ HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window) {
+ // HandleConvolution tries to replace a convolution with a DOT instruction.
+ //
+ // Only rewrite when bitcasts can be used:
+ // - if bitcasts are not supported, reshapes could be used instead but would
+ // end up introducing another copy.
+ // - if bitcasts are supported, the simplifier will be called again with a
+ // valid_bitcast_callback_ that allows them.
+
+ // TODO(cwhipkey): b/31337498, make this layout insensitive.
+ if (!is_layout_sensitive_) return Status::OK();
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+ const Shape& input_shape = lhs->shape();
+ const Shape& filter_shape = rhs->shape();
+ const Shape& convolution_shape = convolution->shape();
+ TF_RET_CHECK(LayoutUtil::HasLayout(input_shape));
+ TF_RET_CHECK(LayoutUtil::HasLayout(filter_shape));
+ TF_RET_CHECK(LayoutUtil::HasLayout(convolution_shape));
+
+ // Require 1x1 filter in the spatial dimensions (so no need to extract image
+ // patches).
+ if (filter_shape.dimensions(dnums.kernel_spatial_dimensions(0)) != 1 ||
+ filter_shape.dimensions(dnums.kernel_spatial_dimensions(1)) != 1) {
+ return Status::OK();
+ }
+
+ // Stride skips over part of the input, which matrix multiplication does not
+ // do, so require no stride. Padding and base (lhs) dilation both implicitly
+ // extend the data, which matrix multiplication also does not do, so require
+ // no padding and no base (lhs) dilation. Window (rhs) dilation has no effect
+ // for a 1x1 window, so window dilation is no problem.
+ if (window_util::HasStride(window) || window_util::HasPadding(window) ||
+ window_util::HasBaseDilation(window)) {
+ return Status::OK();
+ }
+
+ // Also, the shapes must align for a row-major matmul:
+ // - the input and output have the same layout.
+ // - for input/output, the channel dimension must be the most minor. Other
+ // spatial dims can be in any order.
+ // - for filters, the input channel dimension must be more major than the
+ // output channel dimension. The width+height don't matter because
+ // they are 1.
+ //
+ // These constraints are harsh. If the channel dimension is the most major
+ // and/or the layouts of the input/output feature dimensions are reversed, we
+ // can still convert Conv into a more efficient Matmul with operand
+ // transposition (such as the transposition flags in cuBLAS SGEMM).
+ if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) ||
+ input_shape.layout().minor_to_major(0) != dnums.feature_dimension() ||
+ // The input feature dimension should come later in the minor-to-major
+ // order.
+ (PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+ dnums.kernel_input_feature_dimension()) <
+ PositionInContainer(AsInt64Slice(filter_shape.layout().minor_to_major()),
+ dnums.kernel_output_feature_dimension()))) {
+ return Status::OK();
+ }
+
+ auto add_bitcast = [&](Shape shape, HloInstruction* operand) {
+ return computation_->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kBitcast, operand));
+ };
+
+ // Replace it with a dot, with bitcasts around it to get the right shape.
+ const int64 input_channels =
+ input_shape.dimensions(dnums.feature_dimension());
+ const int64 output_channels =
+ filter_shape.dimensions(dnums.kernel_output_feature_dimension());
+
+ // Computes the product of the non-feature dimensions.
+ int64 conv_width = 1;
+ for (int i = 0; i < input_shape.dimensions_size(); ++i) {
+ if (i != dnums.feature_dimension()) {
+ conv_width *= input_shape.dimensions(i);
+ }
+ }
+
+ // We already checked feature_dimension is most minor, so data in input_shape
+ // and row-major {conv_width,input_channels} are bitwise identical.
+ const Shape new_input_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ input_shape.element_type(), {conv_width, input_channels});
+ // We already checked input_feature_dimension is more major than
+ // output_feature_dimension, so data in filter_shape and row-major
+ // {input_channels,output_channels} are bitwise identical.
+ const Shape new_filter_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ filter_shape.element_type(), {input_channels, output_channels});
+ const Shape dot_output_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ convolution_shape.element_type(), {conv_width, output_channels});
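+ // For example (illustrative only): an NHWC f32[10,2,2,3] input convolved
+ // with a 1x1 HWIO f32[1,1,3,10] filter becomes dot(f32[40,3], f32[3,10]) ->
+ // f32[40,10], which is bitcast back to the f32[10,2,2,10] convolution shape.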
+
+ // We cannot insert bitcasts if the layouts will not be compatible.
+ // TODO(b/33178038): Consider inserting a transpose if a bitcast would be
+ // invalid.
+ if (!valid_bitcast_callback_(lhs->shape(), new_input_shape) ||
+ !valid_bitcast_callback_(rhs->shape(), new_filter_shape) ||
+ !valid_bitcast_callback_(dot_output_shape, convolution_shape)) {
+ return Status::OK();
+ }
+
+ auto new_lhs = add_bitcast(new_input_shape, lhs);
+ auto new_rhs = add_bitcast(new_filter_shape, rhs);
+ auto dot = computation_->AddInstruction(HloInstruction::CreateBinary(
+ dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs));
+ computation_->ReplaceInstruction(convolution,
+ add_bitcast(convolution_shape, dot));
+ changed_ = true;
+ return Status::OK();
+}
+
+bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape(
+ HloInstruction* root, HloInstruction* min, HloInstruction* min_operand,
+ HloInstruction* operand, HloInstruction* max, HloInstruction* max_operand) {
+ // Ensure shapes of min and max operand are equal to match current shape
+ // inference.
+ if (!SameShape(min_operand, max_operand)) {
+ return false;
+ }
+
+ auto clamp = HloInstruction::CreateTernary(root->shape(), HloOpcode::kClamp,
+ max_operand, operand, min_operand);
+ computation_->ReplaceWithNewInstruction(root, std::move(clamp));
+ changed_ = true;
+ return true;
+}
+
+Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // Match the following tree:
+ // min_operand operand
+ // \ /
+ // max_operand min
+ // \ /
+ // max
+ // where max_operand and min_operand are scalar constants.
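+ // i.e. max(min(operand, min_operand), max_operand) becomes
+ // clamp(max_operand, operand, min_operand) when the shapes allow it.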
+ {
+ HloInstruction* min;
+ HloInstruction* max_operand;
+ HloInstruction* min_operand;
+ HloInstruction* operand;
+
+ if (hlo_query::MatchBinaryInstructionOperandOpcode(
+ HloOpcode::kMinimum, maximum,
+ /*matching_operand=*/&min,
+ /*other_operand=*/&max_operand) &&
+ hlo_query::MatchBinaryInstructionOperand(
+ hlo_query::IsScalarConstant, min,
+ /*matching_operand=*/&min_operand,
+ /*other_operand=*/&operand) &&
+ TransformToClampIfSameShape(maximum, min, min_operand, operand, maximum,
+ max_operand)) {
+ return Status::OK();
+ }
+ }
+
+ return Status::OK();
+}
+
+Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // Match the following tree:
+ // max_operand operand
+ // \ /
+ // min_operand max
+ // \ /
+ // min
+ // where max_operand and min_operand are scalar constants.
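+ // i.e. min(max(operand, max_operand), min_operand) becomes
+ // clamp(max_operand, operand, min_operand) when the shapes allow it.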
+ {
+ HloInstruction* max;
+ HloInstruction* max_operand;
+ HloInstruction* min_operand;
+ HloInstruction* operand;
+
+ if (hlo_query::MatchBinaryInstructionOperandOpcode(
+ HloOpcode::kMaximum, minimum,
+ /*matching_operand=*/&max,
+ /*other_operand=*/&min_operand) &&
+ hlo_query::MatchBinaryInstructionOperand(
+ hlo_query::IsScalarConstant, max,
+ /*matching_operand=*/&max_operand,
+ /*other_operand=*/&operand) &&
+ TransformToClampIfSameShape(minimum, minimum, min_operand, operand, max,
+ max_operand)) {
+ return Status::OK();
+ }
+ }
+
+ return Status::OK();
+}
+
+StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
+ return std::any_of(
+ module->computations().begin(), module->computations().end(),
+ [=](const std::unique_ptr<HloComputation>& computation) {
+ return AlgebraicSimplifierVisitor::Run(
+ computation.get(), is_layout_sensitive_, valid_bitcast_callback_);
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h
new file mode 100644
index 0000000000..4aa06cab53
--- /dev/null
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// A pass which performs algebraic simplifications.
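+//
+// Example usage (an illustrative sketch; the callback shown here simply
+// declines every proposed bitcast):
+//
+//   AlgebraicSimplifier simplifier(
+//       /*is_layout_sensitive=*/false,
+//       [](const Shape&, const Shape&) { return false; });
+//   bool changed = simplifier.Run(module).ValueOrDie();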
+class AlgebraicSimplifier : public HloPass {
+ public:
+ // Given two shapes, determines if it is valid to bitcast between them.
+ // Precondition: the two shapes have layouts and have the same number of
+ // elements.
+ using ValidBitcastCallback = std::function<bool(const Shape&, const Shape&)>;
+
+ // If is_layout_sensitive is true, then the simplifier preserves layouts
+ // during transformation. Otherwise, layouts are ignored. If
+ // valid_bitcast_callback returns true for a pair of shapes, the pass may
+ // replace a reshape or transpose between those shapes with a bitcast.
+ AlgebraicSimplifier(bool is_layout_sensitive,
+ ValidBitcastCallback valid_bitcast_callback)
+ : HloPass("algsimp"),
+ is_layout_sensitive_(is_layout_sensitive),
+ valid_bitcast_callback_(std::move(valid_bitcast_callback)) {}
+ ~AlgebraicSimplifier() override {}
+
+ // Run algebraic simplification on the given module. Returns whether any
+ // computation in the module was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool is_layout_sensitive_;
+ ValidBitcastCallback valid_bitcast_callback_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALGEBRAIC_SIMPLIFIER_H_
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
new file mode 100644
index 0000000000..49ea91f83b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -0,0 +1,1368 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace {
+
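+// These callbacks stand in for a backend's bitcast validity check: the first
+// accepts every proposed bitcast so the layout-sensitive tests can exercise
+// the bitcast rewrites, and the second declines them all.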
+AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() {
+ return [](const Shape&, const Shape&) { return true; };
+}
+AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() {
+ return [](const Shape&, const Shape&) { return false; };
+}
+
+using AlgebraicSimplifierTest = HloTestBase;
+
+// Test that A + 0 is simplified to A
+TEST_F(AlgebraicSimplifierTest, AddZero) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A - 0 is simplified to A
+TEST_F(AlgebraicSimplifierTest, SubZero) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A/1 is simplified to A for a scalar.
+TEST_F(AlgebraicSimplifierTest, DivOneScalar) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, div);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that A/1 is simplified to A for an array.
+TEST_F(AlgebraicSimplifierTest, DivOneArray) {
+ Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r2f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 1.0}, {1.0, 1.0}})));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, div);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, param0);
+}
+
+// Test that get_element(make_tuple({A,B}),1) is simplified to B
+TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({param0, param1}));
+ HloInstruction* get = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(r0f32, tuple, 1));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, add);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, add);
+ EXPECT_EQ(root->operand(0), param1);
+ EXPECT_EQ(root->operand(1), param2);
+}
+
+// Test that exp(A)/exp(B) is simplified to exp(A-B)
+TEST_F(AlgebraicSimplifierTest, ExpDiv) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kDivide);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kExp);
+ EXPECT_EQ(root->operand_count(), 1);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kSubtract);
+ EXPECT_EQ(root->operand(0)->operand(0), param0);
+ EXPECT_EQ(root->operand(0)->operand(1), param1);
+}
+
+// Test that ln(exp(A)) is simplified to A
+TEST_F(AlgebraicSimplifierTest, LnExp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kLog);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+ EXPECT_EQ(root, param0);
+}
+
+// Test that ln(exp(A)/exp(B)) is simplified to A-B
+TEST_F(AlgebraicSimplifierTest, LnExpDiv) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* exp0 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param0));
+ HloInstruction* exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kExp, param1));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, div));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kLog);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
+ EXPECT_EQ(root->operand(0), param0);
+ EXPECT_EQ(root->operand(1), param1);
+}
+
+// Test that pow(A, 0) where A is a scalar is simplified to the scalar
+// constant 1.
+TEST_F(AlgebraicSimplifierTest, Pow0Scalar) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kConstant);
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->literal()), 1);
+}
+
+// Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1).
+TEST_F(AlgebraicSimplifierTest, Pow0Vector) {
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {42});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), r1f32))
+ << ShapeUtil::HumanString(root->shape());
+ EXPECT_EQ(root->dimensions().size(), 0);
+ EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape()));
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
+ 1);
+}
+
+// Test that pow(A, 1) is simplified to A.
+TEST_F(AlgebraicSimplifierTest, Pow1) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+ EXPECT_EQ(root, param0);
+}
+
+// Test that pow(A, 2) is simplified to A*A.
+TEST_F(AlgebraicSimplifierTest, Pow2) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* two = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
+ EXPECT_EQ(root->operand(0), param0);
+ EXPECT_EQ(root->operand(1), param0);
+}
+
+// Test that pow(A, -1) is simplified to 1/A.
+TEST_F(AlgebraicSimplifierTest, PowNegative1) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* negative_one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(-1)));
+ builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower,
+ param0, negative_one));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kDivide);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kConstant);
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(root->operand(0)->literal()),
+ 1);
+ EXPECT_EQ(root->operand(1), param0);
+}
+
+TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto op = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {3, 2}), "op"));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {6}), op));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 6}), reshape1, {1}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {3, 2}), broadcast));
+
+ auto computation = builder.Build();
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(std::move(computation));
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ HloPassFix<AlgebraicSimplifier> simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kParameter);
+}
+
+// Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE.
+TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertF32ToS64) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(LiteralUtil::GetFirstElement<int64>(
+ computation->root_instruction()->literal()),
+ 42);
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertS64ToF32) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(LiteralUtil::GetFirstElement<float>(
+ computation->root_instruction()->literal()),
+ 42.0f);
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertF32ArrayToS64Array) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
+ builder.AddInstruction(
+ HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kConvert, computation->root_instruction()->opcode());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kConstant, computation->root_instruction()->opcode());
+ EXPECT_EQ(
+ LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {0}),
+ 42);
+ EXPECT_EQ(
+ LiteralUtil::Get<int64>(computation->root_instruction()->literal(), {1}),
+ 19);
+}
+
+// Test that copies are removed.
+TEST_F(AlgebraicSimplifierTest, RemoveCopy) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(param0, computation->root_instruction());
+}
+
+// Test that a simplification which changes layouts is not performed if layout
+// sensitive is true.
+TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ // Set to different layouts.
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Copy has not been removed.
+ EXPECT_EQ(copy, computation->root_instruction());
+}
+
+// Test that a simplification which preserves layouts is performed if layout
+// sensitive is true.
+TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ // Set to same layouts.
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *copy->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ EXPECT_EQ(copy, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Copy has been removed.
+ EXPECT_EQ(param0, computation->root_instruction());
+}
+
+// Test that a reshape which could be replaced with a bitcast is not if
+// add_bitcasts is false.
+TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ *reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(reshape, computation->root_instruction());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Reshape is not replaced with a bitcast.
+ EXPECT_EQ(reshape, computation->root_instruction());
+}
+
+// Test transforming reshapes to bitcasts under various conditions.
+TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+ *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ // Reshape which can be transformed into a bitcast.
+ HloInstruction* transformable_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+ *transformable_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ // Reshape does not just add degenerate dimensions.
+ HloInstruction* dimensions_wrong_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 4, 1, 1, 1, 1}), param0));
+ *dimensions_wrong_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5});
+
+ // Reshape has wrong layout.
+ HloInstruction* layout_wrong_reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), param0));
+ *layout_wrong_reshape->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({5, 4, 3, 2, 1, 0});
+
+ // Collect all the reshapes into a tuple so they are not dead.
+ builder.AddInstruction(HloInstruction::CreateTuple(
+ {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(transformable_reshape, computation->root_instruction()->operand(0));
+ EXPECT_EQ(dimensions_wrong_reshape,
+ computation->root_instruction()->operand(1));
+ EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that only the first reshape is replaced.
+ EXPECT_NE(transformable_reshape, computation->root_instruction()->operand(0));
+ EXPECT_EQ(HloOpcode::kBitcast,
+ computation->root_instruction()->operand(0)->opcode());
+ EXPECT_EQ(dimensions_wrong_reshape,
+ computation->root_instruction()->operand(1));
+ EXPECT_EQ(layout_wrong_reshape, computation->root_instruction()->operand(2));
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {50, 14, 14, 64}), "param"));
+ *param->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({1, 2, 0, 3});
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {14, 14, 50, 64}), param, {1, 2, 0, 3}));
+ *transpose->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that the transpose is replaced.
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {5, 2, 3, 4}), "param"));
+ *param->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({1, 2, 3, 0});
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {5, 3, 4, 2}), param, {0, 2, 3, 1}));
+ *transpose->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({3, 1, 2, 0});
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ // Verify that the transpose is replaced.
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(HloOpcode::kBitcast, computation->root_instruction()->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, ReshapesMerged) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param0"));
+
+ HloInstruction* reshape1 =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {2, 1, 2}), param0));
+
+ HloInstruction* reshape2 =
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(reshape2, computation->root_instruction());
+ EXPECT_EQ(reshape1, computation->root_instruction()->operand(0));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, TransposesMerged) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3, 4}), "param0"));
+
+ HloInstruction* transpose1 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 4, 2}), param0, {1, 2, 0}));
+
+ HloInstruction* transpose2 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_EQ(transpose1, computation->root_instruction()->operand(0));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kTranspose, computation->root_instruction()->opcode());
+ EXPECT_EQ(std::vector<int64>({2, 1, 0}),
+ computation->root_instruction()->dimensions());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+// Test merging reshape and broadcast.
+TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {5}), "param0"));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {1, 5, 1}), param0));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+// Test merging broadcast and reshape.
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3}), "param0"));
+ auto broadcast1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {1, 2, 3, 7, 12, 1}), param0, {1, 2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_EQ(HloOpcode::kParameter,
+ computation->root_instruction()->operand(0)->opcode());
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 1}), param, {1}));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 4}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ EXPECT_MATCH(computation->root_instruction()->dimensions(),
+ testing::VectorMatcher<int64>({3}));
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 1}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kBroadcast, computation->root_instruction()->opcode());
+ const std::vector<int64> broadcast_dims =
+ computation->root_instruction()->dimensions();
+ EXPECT_EQ(1, broadcast_dims.size());
+ EXPECT_TRUE(broadcast_dims[0] == 1 || broadcast_dims[0] == 2 ||
+ broadcast_dims[0] == 3);
+}
+
+TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) {
+ HloComputation::Builder builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4}), "param"));
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2, 4, 2}), param, {2}));
+ builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(F32, {6, 8}), broadcast));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 2}), "param"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ PaddingConfig no_padding;
+ for (auto i = 0; i < 2; ++i) {
+ auto dimension = no_padding.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(0);
+ }
+ builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {2, 3}), "param"));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
+ HloComputation::Builder builder(TestName());
+ const int64 dim0 = 2;
+ const int64 dim1 = 3;
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
+ /*limit_indices=*/{dim0, dim1}));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_EQ(1, computation->instruction_count());
+}
+
+TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
+ struct ConvTestOptions {
+ int in_batch = 10;
+ int in_height = 2;
+ int in_width = 2;
+ int in_channels = 3;
+ int f_width = 1;
+ int f_height = 1;
+ int f_output_channels = 10;
+ int row_stride = 1;
+ int row_padding = 0;
+ int col_stride = 1;
+ int col_padding = 0;
+ bool input_minor_to_major_layout = false;
+ bool filter_minor_to_major_layout = false;
+ bool output_minor_to_major_layout = false;
+
+ const char* dim_order = "NHWC"; // can use chars NHWC in any order.
+ const char* kernel_dim_order = "HWIO"; // can use chars HWIO in any order.
+
+ ConvTestOptions& Reset() {
+ *this = ConvTestOptions();
+ return *this;
+ }
+ };
+
+ ConvTestOptions options;
+
+ // Builds a convolution from <options> and runs algebraic simplification on
+ // the computation. Returns a string description of the result of
+ // simplification.
+ auto build_and_simplify = [&options, this]() -> string {
+ HloComputation::Builder b(TestName());
+
+ Window window;
+ auto* f_dim_1 = window.add_dimensions();
+ f_dim_1->set_size(options.f_height);
+ f_dim_1->set_stride(options.row_stride);
+ f_dim_1->set_padding_low(options.row_padding);
+ f_dim_1->set_padding_high(options.row_padding);
+ f_dim_1->set_window_dilation(1);
+ f_dim_1->set_base_dilation(1);
+ auto* f_dim_2 = window.add_dimensions();
+ f_dim_2->set_size(options.f_width);
+ f_dim_2->set_stride(options.col_stride);
+ f_dim_2->set_padding_low(options.col_padding);
+ f_dim_2->set_padding_high(options.col_padding);
+ f_dim_2->set_window_dilation(1);
+ f_dim_2->set_base_dilation(1);
+
+ ConvolutionDimensionNumbers dnums;
+ std::vector<int64> in_dims;
+ int in_channel_idx = -1;
+ dnums.add_spatial_dimensions(-1); // filled in later
+ dnums.add_spatial_dimensions(-1); // filled in later
+ for (int i = 0; i < strlen(options.dim_order); ++i) {
+ char ch = options.dim_order[i];
+ if (ch == 'N') {
+ dnums.set_batch_dimension(i);
+ in_dims.push_back(options.in_batch);
+ } else if (ch == 'H') {
+ dnums.set_spatial_dimensions(0, i);
+ in_dims.push_back(options.in_height);
+ } else if (ch == 'W') {
+ dnums.set_spatial_dimensions(1, i);
+ in_dims.push_back(options.in_width);
+ } else if (ch == 'C') {
+ dnums.set_feature_dimension(i);
+ in_dims.push_back(options.in_channels);
+ in_channel_idx = i;
+ }
+ }
+
+ std::vector<int64> f_dims;
+ dnums.add_kernel_spatial_dimensions(-1); // filled in later
+ dnums.add_kernel_spatial_dimensions(-1); // filled in later
+ for (int i = 0; i < strlen(options.kernel_dim_order); ++i) {
+ char ch = options.kernel_dim_order[i];
+ if (ch == 'H') {
+ dnums.set_kernel_spatial_dimensions(0, i);
+ f_dims.push_back(options.f_height);
+ } else if (ch == 'W') {
+ dnums.set_kernel_spatial_dimensions(1, i);
+ f_dims.push_back(options.f_width);
+ } else if (ch == 'I') {
+ dnums.set_kernel_input_feature_dimension(i);
+ f_dims.push_back(options.in_channels);
+ } else if (ch == 'O') {
+ dnums.set_kernel_output_feature_dimension(i);
+ f_dims.push_back(options.f_output_channels);
+ }
+ }
+
+ auto out_dims = in_dims;
+ out_dims[in_channel_idx] = options.f_output_channels;
+
+ auto make_shape = [](tensorflow::gtl::ArraySlice<int64> dims,
+ bool minor_to_major_layout) {
+ if (minor_to_major_layout) {
+ return ShapeUtil::MakeShapeWithLayout(F32, dims, {0, 1, 2, 3});
+ } else {
+ return ShapeUtil::MakeShape(F32, dims);
+ }
+ };
+ auto in_shape = make_shape(in_dims, options.input_minor_to_major_layout);
+ auto f_shape = make_shape(f_dims, options.filter_minor_to_major_layout);
+ auto out_shape = make_shape(out_dims, options.output_minor_to_major_layout);
+
+ HloInstruction* input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, in_shape, "input"));
+ HloInstruction* filter =
+ b.AddInstruction(HloInstruction::CreateParameter(1, f_shape, "filter"));
+
+ b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
+ window, dnums));
+
+ HloModule module(TestName());
+ auto* computation = module.AddEntryComputation(b.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
+ bitcasting_callback());
+ if (!simplifier.Run(&module).ValueOrDie()) {
+ return "NO_CHANGE";
+ }
+ auto* root = computation->root_instruction();
+ if (root->opcode() == HloOpcode::kBitcast &&
+ root->operand(0)->opcode() == HloOpcode::kDot) {
+ auto lhs_shape = root->operand(0)->operand(0)->shape();
+ auto rhs_shape = root->operand(0)->operand(1)->shape();
+ return tensorflow::strings::StrCat(
+ tensorflow::str_util::Join(lhs_shape.dimensions(), "x"), " DOT ",
+ tensorflow::str_util::Join(rhs_shape.dimensions(), "x"));
+ }
+ return "UNEXPECTED CHANGE";
+ };
+
+ // Default options are the simplest case and succeed.
+ options.Reset();
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // Swapping the spatial and batch dimension order works.
+ options.Reset().dim_order = "NWHC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().dim_order = "WHNC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ // Channel dimension earlier fails.
+ options.Reset().dim_order = "HWCN";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().dim_order = "CHWN";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // Filter spatial dims can be anywhere, since they are 1x1.
+ options.Reset().kernel_dim_order = "WHIO";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().kernel_dim_order = "IWOH";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ options.Reset().kernel_dim_order = "IWHO";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+ // But moving output channel before input channel fails.
+ options.Reset().kernel_dim_order = "HWOI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "WHOI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "OWIH";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().kernel_dim_order = "OWHI";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // Combine different dim and kernel dim orders.
+ options.Reset().kernel_dim_order = "IWHO";
+ options.dim_order = "WHNC";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // Test invalid cases from wrong filter size, strides, or padding.
+ options.Reset().f_width = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().f_height = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().row_stride = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().col_stride = 2;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().col_padding = 1;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+ options.Reset().row_padding = 1;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // The default dim_order is "NHWC". Col-major layout makes C the most major.
+ options.Reset().input_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // The input and output have different layouts.
+ options.Reset().input_minor_to_major_layout = true;
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // C is most minor, and I is more major than O.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "CHWN";
+ options.kernel_dim_order = "OIHW";
+ EXPECT_EQ("40x3 DOT 3x10", build_and_simplify());
+
+ // C is not the most minor dimension.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "HWNC";
+ options.kernel_dim_order = "OIHW";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+
+ // I is more minor than O.
+ options.Reset().input_minor_to_major_layout = true;
+ options.filter_minor_to_major_layout = true;
+ options.output_minor_to_major_layout = true;
+ options.dim_order = "CHWN";
+ options.kernel_dim_order = "IOHW";
+ EXPECT_EQ("NO_CHANGE", build_and_simplify());
+}
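For orientation, the "40x3 DOT 3x10" strings expected in this test follow directly from the default ConvTestOptions: a 1x1 filter lets the simplifier fold the batch and spatial dims into matmul rows and contract over the input channels. A minimal standalone sketch of that arithmetic (illustrative only, not part of this change):

  #include <cstdio>

  int main() {
    const int in_batch = 10, in_height = 2, in_width = 2;
    const int in_channels = 3, f_output_channels = 10;
    // A 1x1 convolution reduces to a matmul over the channel dimension:
    //   [batch * height * width, in_channels] DOT [in_channels, out_channels].
    const int lhs_rows = in_batch * in_height * in_width;  // 10 * 2 * 2 = 40
    std::printf("%dx%d DOT %dx%d\n", lhs_rows, in_channels, in_channels,
                f_output_channels);  // prints "40x3 DOT 3x10"
    return 0;
  }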
+
+// Test that max(min(A, x), y) is transformed to clamp(y, A, x)
+TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMinimum, param0, min_value));
+ HloInstruction* max = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, max);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ ASSERT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
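The rewrite checked above is structural; arithmetically, max(min(a, x), y) equals clamp(y, a, x) whenever y <= x. A small standalone check of that identity, independent of XLA (illustrative only):

  #include <algorithm>
  #include <cassert>
  #include <initializer_list>

  // Reference clamp(low, value, high), mirroring the operand order checked above.
  float ClampRef(float low, float value, float high) {
    return std::min(std::max(value, low), high);
  }

  int main() {
    const float x = 1.0f;  // inner min bound
    const float y = 0.0f;  // outer max bound; the identity needs y <= x
    for (float a : {-2.0f, 0.25f, 3.0f}) {
      // max(min(a, x), y) is the pattern the simplifier rewrites to clamp.
      assert(std::max(std::min(a, x), y) == ClampRef(y, a, x));
    }
    return 0;
  }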
+
+// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for scalar
+// values.
+TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
+
+// Test that min(max(A, x), y) is transformed to clamp(x, A, y) for
+// broadcasted scalar values.
+TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ Shape r1f32 = ShapeUtil::MakeShape(F32, {100});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kClamp);
+ EXPECT_EQ(root->operand(0), max_value);
+ EXPECT_EQ(root->operand(1), param0);
+ EXPECT_EQ(root->operand(2), min_value);
+}
+
+// Test that min(max(A, non-constant1), non-constant2) is not canonicalized to
+// clamp(non-constant1, A, non-constant2)
+TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "param1"));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r0f32, "param2"));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+}
+
+// Test that min(f(max(A, constant1)), constant2) is not transformed to
+// clamp(constant1, A, constant2)
+TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* min_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ HloInstruction* max_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMaximum, param0, max_value));
+ HloInstruction* fmax = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, max, max_value));
+ HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kMinimum, fmax, min_value));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_EQ(root, min);
+}
+
+// Test that slice(broadcast(/*scalar value*/)) simplifies to a single
+// broadcast.
+TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* scalar_param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
+
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ broadcast_shape, scalar_param,
+ AsInt64Slice(broadcast_shape.dimensions())));
+
+ Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
+ HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
+ slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, slice);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_EQ(scalar_param, root->operand(0));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), slice_shape));
+}
+
+// Test that reshape(transpose(broadcast(/*scalar value*/))) simplifies to a
+// single broadcast.
+TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* forty_two = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+
+ Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ broadcast_shape, forty_two,
+ AsInt64Slice(broadcast_shape.dimensions())));
+
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {6, 5, 4}), broadcast, {2, 1, 0}));
+
+ Shape reshape_shape = ShapeUtil::MakeShape(F32, {30, 1, 4});
+ HloInstruction* reshape = builder.AddInstruction(
+ HloInstruction::CreateReshape(reshape_shape, transpose));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, reshape);
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+ EXPECT_EQ(forty_two, root->operand(0));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reshape_shape));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
new file mode 100644
index 0000000000..a123213401
--- /dev/null
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -0,0 +1,215 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/allocation_tracker.h"
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+AllocationTracker::AllocationTracker() : next_handle_(1) {}
+
+GlobalDataHandle AllocationTracker::Register(Backend* backend,
+ int device_ordinal,
+ se::DeviceMemoryBase device_memory,
+ const Shape& shape,
+ const string& tag) {
+ tensorflow::mutex_lock lock(allocation_mutex_);
+ VLOG(2) << "Register";
+ return RegisterInternal(backend, device_ordinal, device_memory, shape, tag,
+ /*initial_ref_count=*/1);
+}
+
+GlobalDataHandle AllocationTracker::RegisterInternal(
+ Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory,
+ const Shape& shape, const string& tag, int initial_ref_count) {
+ VLOG(2) << "RegisterInternal("
+ << "tag: \"" << tag << "\" "
+ << "device_ordinal: " << device_ordinal << " "
+ << "device_memory: " << device_memory.opaque() << " "
+ << "shape: " << shape.ShortDebugString() << ")";
+ TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
+
+ int64 handle;
+ HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal);
+ auto handle_it = handle_map.find(device_memory.opaque());
+ if (handle_it != handle_map.end()) {
+ handle = handle_it->second;
+ auto& allocation = FindOrDie(handle_to_allocation_, handle);
+ int ref_count = allocation->ref_count();
+ CHECK_GT(ref_count, 0);
+ VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count + 1;
+ allocation->increment_ref_count();
+ } else {
+ handle = next_handle_++;
+ VLOG(2) << "ref_count: " << initial_ref_count;
+ InsertOrDie(&handle_map, device_memory.opaque(), handle);
+ auto inserted = handle_to_allocation_.emplace(
+ handle, MakeUnique<Allocation>(backend, device_ordinal, device_memory,
+ shape, tag, initial_ref_count));
+ CHECK(inserted.second);
+ }
+
+ GlobalDataHandle result;
+ result.set_handle(handle);
+ VLOG(2) << "handle: " << handle;
+
+ return result;
+}
+
+tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
+ tensorflow::mutex_lock lock(allocation_mutex_);
+ TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data));
+ std::set<void*> deallocated_buffers;
+ TF_RETURN_IF_ERROR(
+ DeallocateShape(allocation->backend(), allocation->device_ordinal(),
+ allocation->mutable_device_memory(), allocation->shape(),
+ &deallocated_buffers));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status AllocationTracker::DeallocateShape(
+ Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory,
+ const Shape& shape, std::set<void*>* deallocated_buffers) {
+ VLOG(2) << "DeallocateShape("
+ << "shape: \"" << shape.ShortDebugString() << "\" "
+ << "device_memory: " << device_memory->opaque() << ")";
+ if (ContainsKey(*deallocated_buffers, device_memory->opaque())) {
+ // Buffer has already been deallocated. Nothing to do.
+ VLOG(2) << "already deallocated";
+ return tensorflow::Status::OK();
+ }
+
+ // Add buffer to deallocated set so we do not try to deallocate it again
+ // if it is encountered again while traversing a tuple.
+ deallocated_buffers->insert(device_memory->opaque());
+
+ HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal);
+ auto handle_it = handle_map.find(device_memory->opaque());
+ if (handle_it != handle_map.end()) {
+ int64 handle = handle_it->second;
+ auto& allocation = FindOrDie(handle_to_allocation_, handle);
+ int ref_count = allocation->ref_count();
+ VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1;
+ allocation->decrement_ref_count();
+ if (allocation->ref_count() > 0) {
+ // Buffer is referred to by another allocation. Don't deallocate it.
+ return tensorflow::Status::OK();
+ }
+ handle_map.erase(device_memory->opaque());
+ }
+
+ if (ShapeUtil::IsTuple(shape)) {
+ // Traverse into tuple recursively deallocating buffers.
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+ backend->stream_executor(device_ordinal));
+ TF_ASSIGN_OR_RETURN(std::vector<se::DeviceMemoryBase> elements,
+ backend->transfer_manager()->ShallowCopyTupleFromDevice(
+ executor, *device_memory, shape));
+
+ TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size())
+ << "tuple has unexpected number of elements: " << elements.size()
+ << " != " << ShapeUtil::TupleElementCount(shape);
+ for (int i = 0; i < elements.size(); ++i) {
+ VLOG(2) << "recursing onto the tuple elements";
+ TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i],
+ shape.tuple_shapes(i),
+ deallocated_buffers));
+ }
+ }
+
+ return backend->memory_allocator()->Deallocate(device_ordinal, device_memory);
+}
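DeallocateShape above avoids double-freeing buffers that appear more than once in a nested tuple by threading a set of already-released opaque pointers through the recursion. A stripped-down standalone sketch of that pattern, with plain pointers standing in for DeviceMemoryBase (illustrative only, not part of this change):

  #include <cstdio>
  #include <set>
  #include <vector>

  // Releases each distinct "buffer" once and skips repeats, mirroring the
  // deallocated_buffers bookkeeping above.
  void ReleaseOnce(void* opaque, std::set<void*>* released) {
    if (!released->insert(opaque).second) {
      std::printf("%p already released, skipping\n", opaque);
      return;
    }
    std::printf("releasing %p\n", opaque);
  }

  int main() {
    int a = 0, b = 0;
    std::vector<void*> tuple = {&a, &b, &a};  // &a appears twice (aliased element)
    std::set<void*> released;
    for (void* buffer : tuple) {
      ReleaseOnce(buffer, &released);
    }
    return 0;
  }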
+
+StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
+ const GlobalDataHandle& data) {
+ tensorflow::mutex_lock lock(allocation_mutex_);
+ TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data));
+
+ if (!ShapeUtil::IsTuple(allocation->shape())) {
+ return InvalidArgument("global data handle %lld is not a tuple",
+ data.handle());
+ }
+
+ if (ShapeUtil::IsNestedTuple(allocation->shape())) {
+ return Unimplemented("deconstructing nested tuples not yet supported");
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ allocation->backend()->stream_executor(allocation->device_ordinal()));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_bases,
+ allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice(
+ executor, allocation->device_memory(), allocation->shape()));
+
+ std::vector<GlobalDataHandle> element_handles;
+ for (int i = 0; i < element_bases.size(); ++i) {
+ element_handles.push_back(RegisterInternal(
+ allocation->backend(), allocation->device_ordinal(), element_bases[i],
+ ShapeUtil::GetSubshape(allocation->shape(), {i}),
+ tensorflow::strings::StrCat(allocation->tag(), ".element_", i),
+ /*initial_ref_count=*/2));
+ }
+ return std::move(element_handles);
+}
+
+StatusOr<const Allocation*> AllocationTracker::Resolve(
+ const GlobalDataHandle& data) {
+ tensorflow::mutex_lock lock(allocation_mutex_);
+ return AllocationTracker::ResolveInternal(data);
+}
+
+StatusOr<Allocation*> AllocationTracker::ResolveInternal(
+ const GlobalDataHandle& data) {
+ VLOG(2) << "resolve:" << data.handle();
+ auto it = handle_to_allocation_.find(data.handle());
+ if (it == handle_to_allocation_.end()) {
+ return NotFound("no allocation record for global data handle: %lld",
+ data.handle());
+ }
+ Allocation* allocation = it->second.get();
+
+ if (allocation->is_deallocated()) {
+ return InvalidArgument("global data handle %lld was previously deallocated",
+ data.handle());
+ }
+
+ return allocation;
+}
+
+AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap(
+ int device_ordinal) {
+ if (opaque_to_handle_.size() <= device_ordinal) {
+ opaque_to_handle_.resize(device_ordinal + 1);
+ }
+ return opaque_to_handle_[device_ordinal];
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
new file mode 100644
index 0000000000..e007680016
--- /dev/null
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -0,0 +1,178 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
+
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A global allocation in device space, tracked by the XLA service.
+class Allocation {
+ public:
+ Allocation(Backend* backend, int device_ordinal,
+ perftools::gputools::DeviceMemoryBase device_memory,
+ const Shape& shape, const string& tag, int initial_ref_count)
+ : backend_(backend),
+ device_ordinal_(device_ordinal),
+ device_memory_(device_memory),
+ shape_(shape),
+ tag_(tag),
+ ref_count_(initial_ref_count) {}
+
+ Backend* backend() const { return backend_; }
+ int device_ordinal() const { return device_ordinal_; }
+ perftools::gputools::DeviceMemoryBase device_memory() const {
+ return device_memory_;
+ }
+ const Shape& shape() const { return shape_; }
+ const string& tag() const { return tag_; }
+
+ bool is_deallocated() const {
+ CHECK_GE(ref_count_, 0);
+ return ref_count_ == 0;
+ }
+ int ref_count() const {
+ CHECK_GE(ref_count_, 0);
+ return ref_count_;
+ }
+ void increment_ref_count() {
+ CHECK_GT(ref_count_, 0);
+ CHECK_LT(ref_count_, INT_MAX);
+ ++ref_count_;
+ }
+ void decrement_ref_count() {
+ CHECK_GT(ref_count_, 0);
+ --ref_count_;
+ }
+ perftools::gputools::DeviceMemoryBase* mutable_device_memory() {
+ return &device_memory_;
+ }
+
+ private:
+ // The backend that the memory is allocated on.
+ Backend* backend_;
+
+ // The device that the memory is allocated on.
+ int device_ordinal_;
+
+ // The pointer to this allocation.
+ perftools::gputools::DeviceMemoryBase device_memory_;
+
+ // The shape of this allocation.
+ Shape shape_;
+
+ // An informal description of this allocation shown in tools.
+ string tag_;
+
+ // The number of outstanding references to this memory allocation. It is
+ // incremented each time the same device memory is registered and decremented
+ // when the memory is deallocated.
+ int ref_count_;
+
+ // Return a string representation of this allocation for debugging or logging
+ // purposes.
+ string ToString() const;
+};
+
+// Tracks allocations for the XLA service; allocations can be registered
+// with shape/device/tag and resolved from a handle for later use.
+class AllocationTracker {
+ public:
+ AllocationTracker();
+
+ // Registers device memory with a given shape, device identifier, and tag, and
+ // returns a corresponding handle that can be used for talking to XLA
+ // clients.
+ GlobalDataHandle Register(Backend* backend, int device_ordinal,
+ perftools::gputools::DeviceMemoryBase device_memory,
+ const Shape& shape, const string& tag);
+
+ // Unregister the allocation for the given data handle.
+ tensorflow::Status Unregister(const GlobalDataHandle& data);
+
+ // Returns a vector of global data handles that point to the tuple elements.
+ StatusOr<std::vector<GlobalDataHandle>> DeconstructTuple(
+ const GlobalDataHandle& data);
+
+ // Resolves a handle from an XLA client to an allocation, or returns an
+ // error status if it was not found (or was found but has already been
+ // deallocated).
+ StatusOr<const Allocation*> Resolve(const GlobalDataHandle& data);
+
+ private:
+ // Internal helper which resolves the given GlobalDataHandle to an Allocation.
+ StatusOr<Allocation*> ResolveInternal(const GlobalDataHandle& data)
+ EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+ GlobalDataHandle RegisterInternal(
+ Backend* backend, int device_ordinal,
+ perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape,
+ const string& tag, int initial_ref_count)
+ EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+ // Helper function which deallocates the memory buffer containing the given
+ // shape referred to by device_memory. Tuples are traversed recursively
+ // deallocating all nested buffers. The parameter deallocated_buffers contains
+ // the set of buffers deallocated so far stored as opaque values (void *) from
+ // DeviceMemoryBase. Keeping track of deallocated buffers prevents
+ // double-freeing of buffers which may be referred to more than once in a
+ // nested tuple.
+ tensorflow::Status DeallocateShape(
+ Backend* backend, int device_ordinal,
+ perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape,
+ std::set<void*>* deallocated_buffers)
+ EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+ // Returns the opaque_to_handle_ map for the given device_ordinal, creating
+ // a new map if there is not one for the device_ordinal.
+ using HandleMap = std::map<void*, int64>;
+ HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal)
+ EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_);
+
+ tensorflow::mutex allocation_mutex_; // Guards the allocation mapping.
+
+ // The next handle to assign to an allocation. Guarded by the same mutex as
+ // the mappings, since they are mutated together.
+ int64 next_handle_ GUARDED_BY(allocation_mutex_);
+
+ // A map from DeviceMemoryBase to handle for each device_ordinal.
+ std::vector<HandleMap> opaque_to_handle_ GUARDED_BY(allocation_mutex_);
+
+ // Mapping from GlobalDataHandle handle to the corresponding registered
+ // Allocation object.
+ std::map<int64, std::unique_ptr<Allocation>> handle_to_allocation_
+ GUARDED_BY(allocation_mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ALLOCATION_TRACKER_H_
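Taken together, the tracker's public surface is Register, Resolve, DeconstructTuple, and Unregister. A hedged usage sketch, assuming an existing Backend and a device buffer that was allocated through that backend's memory allocator (the function and names below are illustrative, not part of this change):

  #include "tensorflow/compiler/xla/service/allocation_tracker.h"
  #include "tensorflow/compiler/xla/shape_util.h"
  #include "tensorflow/core/platform/logging.h"

  namespace se = ::perftools::gputools;

  void TrackAndRelease(xla::Backend* backend,
                       se::DeviceMemoryBase device_memory) {
    xla::AllocationTracker tracker;
    // Hand the buffer to the tracker and get a client-visible handle back.
    xla::GlobalDataHandle handle = tracker.Register(
        backend, backend->default_device_ordinal(), device_memory,
        xla::ShapeUtil::MakeShape(xla::F32, {4}), "example buffer");
    // The handle resolves back to the registered allocation's metadata
    // (shape, device ordinal, tag).
    const xla::Allocation* allocation =
        tracker.Resolve(handle).ConsumeValueOrDie();
    VLOG(1) << "registered allocation tagged: " << allocation->tag();
    // Unregister drops the reference and deallocates the device memory.
    if (!tracker.Unregister(handle).ok()) {
      VLOG(1) << "failed to unregister example buffer";
    }
  }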
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
new file mode 100644
index 0000000000..6e76c98c9f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -0,0 +1,237 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/backend.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct Backend::EigenThreadPoolWrapper {
+ explicit EigenThreadPoolWrapper()
+ : pool(new tensorflow::thread::ThreadPool(
+ tensorflow::Env::Default(), "XLAEigen",
+ tensorflow::port::NumSchedulableCPUs())),
+ wrapper(new tensorflow::EigenThreadPoolWrapper(pool.get())),
+ device(new Eigen::ThreadPoolDevice(wrapper.get(),
+ wrapper->NumThreads())) {}
+
+ std::unique_ptr<tensorflow::thread::ThreadPool> pool;
+ std::unique_ptr<tensorflow::EigenThreadPoolWrapper> wrapper;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+/* static */ StatusOr<std::unique_ptr<Backend>> Backend::CreateBackend(
+ perftools::gputools::Platform* platform, int64 replica_count) {
+ if (replica_count == -1) {
+ legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags();
+ replica_count = flags->xla_replicas;
+ }
+ TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
+ TF_ASSIGN_OR_RETURN(auto stream_executors,
+ PlatformUtil::GetStreamExecutors(platform));
+ TF_ASSIGN_OR_RETURN(auto transfer_manager,
+ TransferManager::GetForPlatform(platform));
+ std::unique_ptr<Backend> backend(new Backend(
+ replica_count, platform, compiler, stream_executors, transfer_manager));
+ TF_RETURN_IF_ERROR(backend->PoolStreams(kInitialStreamsToPool,
+ backend->default_stream_executor()));
+ return std::move(backend);
+}
+
+/* static */ StatusOr<std::unique_ptr<Backend>>
+Backend::CreateDefaultBackend() {
+ TF_ASSIGN_OR_RETURN(se::Platform * platform,
+ PlatformUtil::GetDefaultPlatform());
+ return CreateBackend(platform);
+}
+
+tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) {
+ std::vector<std::unique_ptr<se::Stream>> primed;
+ for (int i = 0; i < n; ++i) {
+ TF_ASSIGN_OR_RETURN(auto stream, AcquireStream(executor));
+ primed.emplace_back(std::move(stream));
+ }
+ for (int i = 0; i < n; ++i) {
+ ReleaseStream(std::move(primed.back()));
+ primed.pop_back();
+ }
+ return tensorflow::Status::OK();
+}
+
+StatusOr<std::unique_ptr<perftools::gputools::Stream>> Backend::AcquireStream(
+ perftools::gputools::StreamExecutor* executor) {
+ tensorflow::mutex_lock lock(mutex_);
+ auto& cached_streams = cached_streams_[executor];
+ if (!cached_streams.empty()) {
+ auto result = std::move(cached_streams.back());
+ cached_streams.pop_back();
+ return std::move(result);
+ }
+
+ auto stream = MakeUnique<se::Stream>(executor);
+ if (!stream->Init().ok()) {
+ return InternalError("failed to initialize stream");
+ }
+ return std::move(stream);
+}
+
+void Backend::ReleaseStream(
+ std::unique_ptr<perftools::gputools::Stream> stream) {
+ tensorflow::mutex_lock lock(mutex_);
+ auto& streams = cached_streams_[stream->parent()];
+ streams.emplace_back(std::move(stream));
+}
+
+Backend::Backend(
+ int64 replica_count, perftools::gputools::Platform* platform,
+ Compiler* compiler,
+ tensorflow::gtl::ArraySlice<se::StreamExecutor*> stream_executors,
+ TransferManager* transfer_manager)
+ : platform_(platform),
+ compiler_(compiler),
+ transfer_manager_(transfer_manager),
+ replica_count_(replica_count) {
+ // The given set of stream executors may include invalid executors.
+ for (se::StreamExecutor* exec : stream_executors) {
+ if (exec != nullptr) {
+ stream_executors_.push_back(exec);
+ }
+ }
+ CHECK_GE(replica_count, 1) << "Must request at least 1 replica.";
+
+ // Create a memory allocator for the valid stream executors.
+ memory_allocator_ =
+ MakeUnique<StreamExecutorMemoryAllocator>(platform, stream_executors);
+
+ // First check that there are some non-null stream executors to avoid issuing
+ // an error mentioning replicas in the common case of requesting just 1
+ // replica, which means no replication.
+ CHECK(!stream_executors_.empty())
+ << "Service found no devices for backend " << platform_->Name() << '.';
+ CHECK_GE(stream_executors_.size(), replica_count)
+ << "Requested more replicas than there are devices for backend "
+ << platform_->Name() << '.';
+
+ if (platform->id() == se::host::kHostPlatformId) {
+ inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool(
+ tensorflow::Env::Default(), "xla_inter_op",
+ tensorflow::port::NumSchedulableCPUs()));
+ intra_op_thread_pool_wrapper_.reset(new EigenThreadPoolWrapper());
+ }
+}
+
+Backend::~Backend() {}
+
+int Backend::default_device_ordinal() const {
+ return default_stream_executor()->device_ordinal();
+}
+
+StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Backend::Replicas(
+ int device_ordinal) const {
+ if (stream_executors_[device_ordinal] == nullptr) {
+ return InvalidArgument("device %s not supported by XLA service",
+ device_name(device_ordinal).c_str());
+ }
+
+ // Find replica_count_ stream executors starting from the given device
+ // ordinal.
+ std::vector<perftools::gputools::StreamExecutor*> replicas;
+ for (se::StreamExecutor* exec : stream_executors_) {
+ CHECK(exec != nullptr);
+ if (exec->device_ordinal() >= device_ordinal) {
+ replicas.push_back(exec);
+ if (replicas.size() >= replica_count_) {
+ return replicas;
+ }
+ }
+ }
+
+ return InvalidArgument(
+ "Not enough devices for replicas for the device ordinal %d",
+ device_ordinal);
+}
+
+std::vector<perftools::gputools::StreamExecutor*> Backend::Replicas() const {
+ CHECK_GE(stream_executors_.size(), replica_count_);
+ return Replicas(default_device_ordinal()).ValueOrDie();
+}
+
+tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+}
+
+const Eigen::ThreadPoolDevice* Backend::eigen_intra_op_thread_pool_device()
+ const {
+ if (intra_op_thread_pool_wrapper_ == nullptr) return nullptr;
+ return intra_op_thread_pool_wrapper_->device.get();
+}
+
+StatusOr<perftools::gputools::StreamExecutor*> Backend::stream_executor(
+ int device_ordinal) const {
+ if (device_ordinal < 0 ||
+ device_ordinal > stream_executors_.back()->device_ordinal()) {
+ return InvalidArgument(
+ "Invalid device ordinal value (%d). Valid range is [0, %d].",
+ device_ordinal, stream_executors_.back()->device_ordinal());
+ }
+ for (auto* executor : stream_executors_) {
+ if (executor->device_ordinal() == device_ordinal) {
+ return executor;
+ }
+ }
+ return InvalidArgument("device %s not supported by XLA service",
+ device_name(device_ordinal).c_str());
+}
+
+StatusOr<bool> Backend::devices_equivalent(int device_ordinal_a,
+ int device_ordinal_b) {
+ // Use the name from device description to determine equivalence. This is a
+ // bit crude but works for GPUs which is the important case where we compile
+ // an executable for one GPU and want to know if it will run (well) on
+ // another.
+ TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_a,
+ stream_executor(device_ordinal_a));
+ TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * executor_b,
+ stream_executor(device_ordinal_b));
+ return (executor_a->GetDeviceDescription().name() ==
+ executor_b->GetDeviceDescription().name());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
new file mode 100644
index 0000000000..17c53d299e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -0,0 +1,191 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace Eigen {
+class ThreadPoolDevice;
+}
+
+namespace xla {
+
+// Class which encapsulates an XLA backend. It includes everything necessary
+// to compile and execute computations on a particular platform.
+//
+// It also offers a pooling API for creation/use of initialized streams:
+//
+// std::unique_ptr<se::Stream> stream =
+// backend->AcquireStream().ConsumeValueOrDie();
+// // ... use stream ...
+// backend->ReleaseStream(std::move(stream));
+class Backend {
+ public:
+ // The number of streams we create for the pool at initialization time.
+ static constexpr int kInitialStreamsToPool = 8;
+
+ // Creates a new backend for the given platform with the given number of
+ // replicas. A value of -1 means to use the flag value.
+ static StatusOr<std::unique_ptr<Backend>> CreateBackend(
+ perftools::gputools::Platform* platform, int64 replica_count = -1);
+
+ // Creates a backend for the default platform. The default platform is defined
+ // in PlatformUtil.
+ static StatusOr<std::unique_ptr<Backend>> CreateDefaultBackend();
+
+ ~Backend();
+
+ // Accessors for the various objects.
+ perftools::gputools::Platform* platform() const { return platform_; }
+ Compiler* compiler() const { return compiler_; }
+ DeviceMemoryAllocator* memory_allocator() const {
+ return memory_allocator_.get();
+ }
+ TransferManager* transfer_manager() const { return transfer_manager_; }
+
+ // Returns the number of devices of the platform type which are visible. Not
+ // all of these devices may be usable by XLA.
+ int device_count() const { return stream_executors_.size(); }
+
+ // Returns the device ordinal number of the default device.
+ int default_device_ordinal() const;
+
+ // Returns stream executors of all supported devices for this backend. The
+ // executors are ordered by the device ordinal.
+ const std::vector<perftools::gputools::StreamExecutor*>& stream_executors()
+ const {
+ return stream_executors_;
+ }
+
+ // Returns the replicas for the default stream executor.
+ //
+ // When the number of replicas is R, the first R stream executors are assigned
+ // to the replicas of the default stream executor.
+ std::vector<perftools::gputools::StreamExecutor*> Replicas() const;
+
+ // Returns the replicas for the given device_ordinal. The given device ordinal
+ // is considered to be the first device ordinal among the replicas. Returns an
+ // error status if the stream executor for the given device ordinal does
+ // not exist or if there are not enough stream executors for the replicas.
+ StatusOr<std::vector<perftools::gputools::StreamExecutor*>> Replicas(
+ int device_ordinal) const;
+
+ // Return the stream executor for the given device ordinal.
+ StatusOr<perftools::gputools::StreamExecutor*> stream_executor(
+ int device_ordinal) const;
+
+ // Return the stream executor for the default device ordinal.
+ perftools::gputools::StreamExecutor* default_stream_executor() const {
+ CHECK(!stream_executors_.empty());
+ return stream_executors_[0];
+ }
+
+ // Primes the internal pool of streams for AcquireStream/ReleaseStream with n
+ // initialized stream instances.
+ tensorflow::Status PoolStreams(int n,
+ perftools::gputools::StreamExecutor* executor);
+
+ // Acquires a stream for use by the caller, either by grabbing it from an
+ // internal pool, or by constructing/initializing it, and returns the result
+ // to the caller.
+ //
+ // TODO(b/32989582): Return std::unique_ptr with custom deleter.
+ StatusOr<std::unique_ptr<perftools::gputools::Stream>> AcquireStream(
+ perftools::gputools::StreamExecutor* executor);
+
+ // Releases a stream from the caller to the internal pool, for use with the
+ // paired AcquireStream above.
+ void ReleaseStream(std::unique_ptr<perftools::gputools::Stream> stream);
+
+ // Returns whether the given device ordinal of the backend is supported.
+ bool device_ordinal_supported(int device_ordinal) const {
+ return (device_ordinal >= 0 && device_ordinal < device_count() &&
+ stream_executors_[device_ordinal] != nullptr);
+ }
+
+ // Return a string identifier for the given device, e.g. "GPU:3".
+ string device_name(int device_ordinal) const {
+ return tensorflow::strings::StrCat(platform_->Name(), ":", device_ordinal);
+ }
+
+ // Returns true if the devices with the given ordinals are equivalent from
+ // XLA's perspective. That is, an executable compiled for one device would
+ // be equivalent to an executable compiled for the other.
+ StatusOr<bool> devices_equivalent(int device_ordinal_a, int device_ordinal_b);
+
+ // For the host platform, returns the threadpool to use when scheduling
+ // parallel operators. For other platforms, returns NULL.
+ tensorflow::thread::ThreadPool* inter_op_thread_pool() const;
+
+ // For the host platform, returns the configured eigen threadpool device to be
+ // used for scheduling work. For other platforms, returns NULL.
+ const Eigen::ThreadPoolDevice* eigen_intra_op_thread_pool_device() const;
+
+ private:
+ struct EigenThreadPoolWrapper;
+ Backend(int64 replica_count, perftools::gputools::Platform* platform,
+ Compiler* compiler,
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ stream_executors,
+ TransferManager* transfer_manager);
+ Backend(const Backend&) = delete;
+ Backend& operator=(const Backend&) = delete;
+
+ perftools::gputools::Platform* platform_;
+ Compiler* compiler_;
+ TransferManager* transfer_manager_;
+ int64 replica_count_ = -1;
+
+ // Vector of stream executors. stream_executors_[0] is the default executor.
+ std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
+
+ // Guards the mutable state in the backend object.
+ tensorflow::mutex mutex_;
+
+ // Mapping from stream executor to cached streams, used by
+ // AcquireStream/ReleaseStream above.
+ std::map<perftools::gputools::StreamExecutor*,
+ std::vector<std::unique_ptr<perftools::gputools::Stream>>>
+ cached_streams_ GUARDED_BY(mutex_);
+
+ // The default memory allocator to use.
+ std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
+
+ // For the CPU backend, a threadpool for scheduling parallel operators.
+ std::unique_ptr<tensorflow::thread::ThreadPool> inter_op_thread_pool_;
+
+ // For the CPU backend, an Eigen threadpool device for use by Eigen code.
+ std::unique_ptr<EigenThreadPoolWrapper> intra_op_thread_pool_wrapper_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BACKEND_H_
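The pooling API described in the class comment above can be exercised end to end roughly as follows; this is a sketch assuming the default platform is available in the build environment (the function name is illustrative):

  #include <memory>
  #include <utility>

  #include "tensorflow/compiler/xla/service/backend.h"

  namespace se = ::perftools::gputools;

  void RunOnPooledStream() {
    std::unique_ptr<xla::Backend> backend =
        xla::Backend::CreateDefaultBackend().ConsumeValueOrDie();
    // Grab a primed stream from the pool (or construct one if the pool is empty).
    std::unique_ptr<se::Stream> stream =
        backend->AcquireStream(backend->default_stream_executor())
            .ConsumeValueOrDie();
    // ... enqueue work on *stream ...
    // Return the stream to the pool for reuse.
    backend->ReleaseStream(std::move(stream));
  }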
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
new file mode 100644
index 0000000000..1616a1363d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -0,0 +1,777 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Defines the data returned by the XLA buffer assignment packages.
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+
+#include <algorithm>
+#include <deque>
+#include <ostream>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+
+void BufferAllocation::AddAssignment(const LogicalBuffer& buffer) {
+ DCHECK(std::find(assigned_buffers_.begin(), assigned_buffers_.end(),
+ &buffer) == assigned_buffers_.end())
+ << "LogicalBuffer " << buffer.ToString()
+ << " already assigned to allocation " << index();
+ assigned_buffers_.push_back(&buffer);
+}
+
+string BufferAllocation::ToString() const {
+ string output;
+ tensorflow::strings::StrAppend(
+ &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld",
+ index_, this, size()));
+ if (is_entry_computation_parameter()) {
+ tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number());
+ }
+ if (is_thread_local()) {
+ tensorflow::strings::StrAppend(&output, ", thread-local");
+ }
+ tensorflow::strings::StrAppend(&output, ":\n");
+ for (const auto& buffer : assigned_buffers()) {
+ tensorflow::strings::StrAppend(
+ &output,
+ tensorflow::strings::Printf(
+ " %s::%s : %s\n", buffer->instruction()->parent()->name().c_str(),
+ buffer->ToString().c_str(),
+ ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str()));
+ }
+ return output;
+}
+
+std::ostream& operator<<(std::ostream& out, const BufferAllocation& buffer) {
+ out << buffer.ToString();
+ return out;
+}
+
+const PointsToSet& BufferAssignment::GetPointsToSet(
+ const HloInstruction* instruction) const {
+ return points_to_analysis().GetPointsToSet(instruction);
+}
+
+bool BufferAssignment::HasAllocation(const LogicalBuffer& buffer) const {
+ TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
+ return allocation_index_for_buffer_.count(&buffer) > 0;
+}
+
+const BufferAllocation& BufferAssignment::GetAssignedAllocation(
+ const LogicalBuffer& buffer) const {
+ CHECK(HasAllocation(buffer));
+ return GetAllocation(allocation_index_for_buffer_.at(&buffer));
+}
+
+BufferAllocation* BufferAssignment::GetMutableAssignedAllocation(
+ const LogicalBuffer& buffer) {
+ return const_cast<BufferAllocation*>(&GetAssignedAllocation(buffer));
+}
+
+std::set<BufferAllocation> BufferAssignment::GetAllocations(
+ const HloInstruction* instruction, const ShapeIndex& index) const {
+ std::set<BufferAllocation> allocations;
+ for (const LogicalBuffer* buffer : GetSourceBuffers(instruction, index)) {
+ if (allocation_index_for_buffer_.count(buffer) > 0) {
+ allocations.insert(
+ GetAllocation(allocation_index_for_buffer_.at(buffer)));
+ }
+ }
+ return allocations;
+}
+
+const BufferAllocation& BufferAssignment::GetAllocation(
+ BufferAllocation::Index index) const {
+ CHECK(index >= 0 && index < allocations_.size())
+ << "Allocation index " << index << "is out of range.";
+ return allocations_[index];
+}
+
+BufferAllocation* BufferAssignment::GetMutableAllocation(
+ BufferAllocation::Index index) {
+ return const_cast<BufferAllocation*>(&GetAllocation(index));
+}
+
+bool BufferAssignment::HasTopLevelAllocation(
+ const HloInstruction* instruction) const {
+ for (const LogicalBuffer* buffer :
+ GetPointsToSet(instruction).element(/*index=*/{})) {
+ if (allocation_index_for_buffer_.count(buffer) > 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+StatusOr<const BufferAllocation*> BufferAssignment::GetUniqueAllocation(
+ const HloInstruction* instruction, const ShapeIndex& index) const {
+ const BufferAllocation* allocation = nullptr;
+ for (const LogicalBuffer* buffer :
+ GetPointsToSet(instruction).element(index)) {
+ if (HasAllocation(*buffer)) {
+ if (allocation != nullptr &&
+ *allocation != GetAssignedAllocation(*buffer)) {
+ return FailedPrecondition(
+ "LogicalBuffer allocation for instruction %s at index {%s} cannot "
+ "be determined at compile-time.",
+ instruction->name().c_str(),
+ tensorflow::str_util::Join(index, ",").c_str());
+ }
+ allocation = &GetAssignedAllocation(*buffer);
+ }
+ }
+ if (allocation == nullptr) {
+ return FailedPrecondition(
+ "instruction %s has no buffer allocation at index {%s}",
+ instruction->name().c_str(),
+ tensorflow::str_util::Join(index, ",").c_str());
+ }
+ return allocation;
+}
+
+StatusOr<const BufferAllocation*> BufferAssignment::GetUniqueTopLevelAllocation(
+ const HloInstruction* instruction) const {
+ return GetUniqueAllocation(instruction, /*index=*/{});
+}
+
+StatusOr<const BufferAllocation*>
+BufferAssignment::GetUniqueTopLevelOutputAllocation() const {
+ return GetUniqueTopLevelAllocation(
+ module_->entry_computation()->root_instruction());
+}
+
+BufferAllocation* BufferAssignment::NewAllocation(const LogicalBuffer& buffer,
+ int64 size,
+ bool is_thread_local) {
+ BufferAllocation::Index index = allocations_.size();
+ allocations_.emplace_back(index, size, is_thread_local);
+ BufferAllocation* allocation = &allocations_.back();
+ AddAssignment(buffer, allocation);
+ allocation_index_for_buffer_[&buffer] = index;
+ return allocation;
+}
+
+// Assigns the given buffer to the given allocation.
+void BufferAssignment::AddAssignment(const LogicalBuffer& buffer,
+ BufferAllocation* allocation) {
+ CHECK_EQ(0, allocation_index_for_buffer_.count(&buffer))
+ << "LogicalBuffer " << buffer << " already has an allocation.";
+ TF_CHECK_OK(points_to_analysis().VerifyBuffer(buffer));
+
+ allocation->AddAssignment(buffer);
+ allocation_index_for_buffer_[&buffer] = allocation->index();
+}
+
+string BufferAssignment::ToString() const {
+ string output;
+ tensorflow::strings::StrAppend(&output, "BufferAssignment:\n");
+ for (auto& allocation : allocations_) {
+ tensorflow::strings::StrAppend(&output, allocation.ToString());
+ }
+ return output;
+}
+
+namespace {
+
+// Walk the call graph of the HLO module and place each computation into either
+// thread_local_computations or global_computations depending upon whether the
+// computation requires thread-local allocations or global allocations. The
+// elements in thread_local_computations and global_computations are in post
+// order (if computation A has an instruction which calls computation B, then A
+// will appear after B in the vector).
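+//
+// As a hypothetical illustration (not a module from this change): if the entry
+// computation contains a kWhile whose body calls a kMap, then the while
+// condition and body end up in global_computations (before the entry
+// computation, which calls them), while the map's applied computation ends up
+// in thread_local_computations.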
+tensorflow::Status GatherComputationsByAllocationType(
+ const HloModule* module,
+ std::vector<const HloComputation*>* thread_local_computations,
+ std::vector<const HloComputation*>* global_computations) {
+ // Create a worklist of computations paired with whether the allocation must
+ // be thread-local.
+ std::deque<std::pair<HloComputation*, bool>> worklist;
+ worklist.push_back(std::make_pair(module->entry_computation(),
+ /*is_thread_local*/ false));
+
+ // Sets for quickly checking membership. Computations are returned in vectors
+ // for stable iteration.
+ std::unordered_set<HloComputation*> thread_local_set;
+ std::unordered_set<HloComputation*> global_set;
+
+ while (!worklist.empty()) {
+ auto worklist_front = worklist.front();
+ worklist.pop_front();
+ HloComputation* computation = worklist_front.first;
+ bool is_thread_local = worklist_front.second;
+ bool in_thread_local_set = thread_local_set.count(computation) > 0;
+ bool in_global_set = global_set.count(computation) > 0;
+
+ // If the computation has already been added to the respective set, then
+ // nothing to do.
+ if ((is_thread_local && in_thread_local_set) ||
+ (!is_thread_local && in_global_set)) {
+ continue;
+ }
+
+ // If the computation has already been added to the other set this is an
+ // error condition because the global call to the computation (eg,
+ // while/call) may return a reference to one of the thread-local buffers to
+ // the calling computation which will become a dangling reference when the
+ // thread-local is deallocated with the call return.
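+ // A hypothetical example: a computation invoked both by a kCall (global
+ // context) and as the applied computation of a kMap (thread-local context)
+ // would hit this error.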
+ if ((is_thread_local && in_global_set) ||
+ (!is_thread_local && in_thread_local_set)) {
+ return InvalidArgument(
+ "computation %s has conflicting allocation requirements (global "
+ "and thread-local)",
+ computation->name().c_str());
+ }
+
+ if (is_thread_local) {
+ thread_local_set.insert(computation);
+ } else {
+ global_set.insert(computation);
+ }
+
+ for (auto& instruction : computation->instructions()) {
+ for (auto* subcomputation : instruction->MakeCalledComputationsSet()) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kWhile:
+ // Call and while must be called from a computation with global
+ // allocations as they may return references to buffers inside the
+ // called computation which cannot be thread-local.
+ if (is_thread_local) {
+ return InvalidArgument(
+ "computation %s cannot contain call/while op because it "
+ "requires thread-local buffer allocations",
+ computation->name().c_str());
+ }
+ worklist.push_back(std::make_pair(subcomputation,
+ false)); // Not thread local.
+ break;
+ case HloOpcode::kMap:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kFusion:
+ // Map/reduce etc computations are always thread-local.
+ worklist.push_back(std::make_pair(subcomputation,
+ true)); // Thread local.
+ break;
+ default:
+ return InternalError(
+ "Unexpected calling opcode: %s",
+ HloOpcodeString(instruction->opcode()).c_str());
+ }
+ }
+ }
+ }
+
+ // Add the computations to the vectors in post order.
+ for (auto* computation : module->MakeComputationPostOrder()) {
+ if (thread_local_set.count(computation) > 0) {
+ thread_local_computations->push_back(computation);
+ } else if (global_set.count(computation) > 0) {
+ global_computations->push_back(computation);
+ }
+ // If the computation is not reachable from the entry computation, then it
+ // will not appear in either thread_local_set or global_set. We don't bother
+ // assigning buffers for these.
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+/* static */
+StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ BufferSizeFunction buffer_size, bool colocate_related_buffers,
+ const std::vector<const HloInstruction*>* hlos_to_allocate) {
+ BufferAssigner assigner(std::move(buffer_size), colocate_related_buffers);
+ return assigner.CreateAssignment(module, std::move(hlo_ordering),
+ hlos_to_allocate);
+}
+
+/* static */
+StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ int64 pointer_size) {
+ return BufferAssigner::Run(module, std::move(hlo_ordering),
+ [pointer_size](const LogicalBuffer& buffer) {
+ return ShapeUtil::IsOpaque(buffer.shape())
+ ? 0
+ : ShapeUtil::ByteSizeOf(
+ buffer.shape(), pointer_size);
+ },
+ /*colocate_related_buffers=*/true);
+}
+
+bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
+ const LogicalBuffer& buffer,
+ BufferAssignment* assignment) {
+ CHECK(!assignment->HasAllocation(buffer))
+ << "buffer " << buffer << " already has an allocation assigned.";
+
+ VLOG(4) << "Trying to assign " << buffer.ToString()
+ << " to allocation: " << allocation->ToString();
+
+ if (buffer_size_(buffer) > allocation->size()) {
+ VLOG(4) << "Can't assign: buffer is larger than allocation ("
+ << buffer_size_(buffer) << " > " << allocation->size() << ")";
+ return false;
+ }
+
+ if (allocation->is_entry_computation_parameter()) {
+ VLOG(4) << "Can't assign: allocation holds parameter";
+ return false;
+ }
+
+ for (const LogicalBuffer* assigned_buffer : allocation->assigned_buffers()) {
+ if (assignment->liveness().MayInterfere(*assigned_buffer, buffer)) {
+ VLOG(4) << "Can't assign: assignee " << assigned_buffer->ToString()
+ << " may interfere with " << buffer.ToString();
+ return false;
+ }
+ }
+
+ // If the buffer is live out of the computation then it should only be
+ // assigned a buffer which exactly fits the result to avoid wasting memory
+ // (result buffers can have arbitrary lifetimes).
+ if (assignment->liveness().MaybeLiveOut(buffer) &&
+ allocation->size() != buffer_size_(buffer)) {
+ VLOG(4) << "Can't assign: buffer " << buffer.ToString()
+ << "is live out and size not the same as allocation";
+ return false;
+ }
+
+ assignment->AddAssignment(buffer, allocation);
+ return true;
+}
+
+tensorflow::Status BufferAssigner::AssignBuffersForComputation(
+ const HloComputation* computation, bool is_thread_local,
+ const std::unordered_set<const HloInstruction*>* hlos_to_allocate,
+ BufferAssignment* assignment) {
+ // Buffers are sorted and assigned to BufferAllocations in decreasing order of
+ // size.
+ std::vector<const LogicalBuffer*> sorted_buffers;
+ for (auto& instruction : computation->instructions()) {
+ if (hlos_to_allocate == nullptr ||
+ hlos_to_allocate->count(instruction.get()) > 0) {
+ // Add all buffers which this instruction defines. Instructions which don't
+ // define buffers (eg, a bitcast which just forwards a pointer) don't need
+ // any allocations.
+ for (const LogicalBuffer* buffer :
+ assignment->points_to_analysis().GetBuffersDefinedByInstruction(
+ instruction.get())) {
+ sorted_buffers.push_back(buffer);
+ }
+ }
+ }
+
+ // Generate a post order sort of instructions for sorting of the
+ // LogicalBuffers.
+ std::unordered_map<const HloInstruction*, int> post_order_position;
+ int position = 0;
+ for (auto* instruction : computation->MakeInstructionPostOrder()) {
+ post_order_position.emplace(instruction, position);
+ position++;
+ }
+
+ // Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
+ // first for simplicity. This means any previously created BufferAllocation is
+ // necessarily large enough to hold the buffer currently under
+ // consideration.
+ //
+ // As a secondary sorting criterion, use the post order position of the HLO
+ // instruction which defines the buffer. This means an instruction will appear
+ // after its operands (assuming operands are the same/larger size) enabling
+ // the important reuse case where an elementwise instruction reuses one of its
+ // operand's buffer. This improves locality.
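+ //
+ // A hypothetical example of the resulting order (buffer names are
+ // illustrative only): with buffers {a: 64 bytes, b: 64 bytes, c: 16 bytes}
+ // where a's defining instruction precedes b's in post order, the sorted
+ // order is [a, b, c]; equal sizes fall back to post order and the smallest
+ // buffer comes last.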
+ std::sort(sorted_buffers.begin(), sorted_buffers.end(),
+ [this, &post_order_position](const LogicalBuffer* a,
+ const LogicalBuffer* b) {
+ int64 a_size = buffer_size_(*a);
+ int64 b_size = buffer_size_(*b);
+ if (a_size == b_size) {
+ // For instructions with the same size buffers, sort them in
+ // post order.
+ return post_order_position.at(a->instruction()) <
+ post_order_position.at(b->instruction());
+ } else {
+ // We want the HLOs sorted in reverse order by size so use ">".
+ return a_size > b_size;
+ }
+ });
+
+ // BufferAllocations are necessarily created in decreasing size order. Keep
+ // indices of previously created BufferAllocations in allocation_indices.
+ std::vector<BufferAllocation::Index> allocation_indices;
+ for (const auto* buffer : sorted_buffers) {
+ VLOG(3) << "Assigning allocation to: " << buffer->ToString();
+ if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) {
+ // Colocated buffers are currently assigned in an earlier pass.
+ continue;
+ }
+
+ TF_RET_CHECK(!assignment->HasAllocation(*buffer));
+
+ if (buffer->instruction()->opcode() == HloOpcode::kConstant) {
+ // No BufferAllocations for constants.
+ // TODO(b/32248867): For consistency, constants should get allocations.
+ continue;
+ }
+
+ if (buffer->instruction()->opcode() == HloOpcode::kParameter &&
+ computation == computation->parent()->entry_computation()) {
+ // If the LogicalBuffer is part of an external parameter, create a new
+ // allocation and set its parameter number. Parameters of non-entry
+ // computations do not need special allocations because they live inside
+ // callers.
+ BufferAllocation* allocation =
+ assignment->NewAllocation(*buffer, buffer_size_(*buffer),
+ /*is_thread_local=*/false);
+ allocation->set_entry_computation_parameter(
+ buffer->instruction()->parameter_number());
+ VLOG(3) << "New allocation for entry computation parameter: "
+ << buffer->ToString();
+ continue;
+ }
+
+ legacy_flags::BufferAssignmentFlags* flags =
+ legacy_flags::GetBufferAssignmentFlags();
+ if (!flags->xla_enable_buffer_reuse || is_thread_local ||
+ buffer->instruction()->opcode() == HloOpcode::kCustomCall) {
+ // Custom call operations never have reusable buffers. Also we do not
+ // reuse thread-local buffers for now, because they are dynamically
+ // allocated and their lifetimes are hard to compute.
+ assignment->NewAllocation(*buffer, buffer_size_(*buffer),
+ is_thread_local);
+ continue;
+ }
+
+ if (buffer->instruction()->opcode() == HloOpcode::kCall &&
+ buffer->IsTopLevel()) {
+ // Assign the kCall instruction the same allocation as the root of the
+ // called computation. The points-to set of the root of the called
+ // computation must be unambiguous so we know statically the allocation for
+ // the root.
+ //
+ // TODO(b/32491382): This is a hack. To properly handle this case
+ // points-to analysis, liveness analysis, and buffer assignment need to be
+ // module-scoped rather than computation-scoped.
+ HloInstruction* call = buffer->instruction();
+ HloInstruction* computation_root = call->to_apply()->root_instruction();
+
+ // The buffer of the root of the called computation must be unambiguous.
+ const auto& root_points_to = assignment->GetPointsToSet(computation_root);
+ if (root_points_to.IsAmbiguous()) {
+ return Unimplemented(
+ "kCall of a computation with an ambiguous root points-to set");
+ }
+ CHECK_EQ(1, root_points_to.element(/*index=*/{}).size());
+ const LogicalBuffer* root_buffer =
+ root_points_to.element(/*index=*/{})[0];
+ BufferAllocation* root_allocation =
+ assignment->GetMutableAssignedAllocation(*root_buffer);
+
+ // Can't use MaybeAssignBuffer here because buffer liveness conservatively
+ // assumes buffers in different computations always interfere.
+ CHECK_GE(root_allocation->size(), buffer_size_(*buffer));
+ assignment->AddAssignment(*buffer, root_allocation);
+ continue;
+ }
+
+ // First try to assign a LogicalBuffer to one of its operand allocations to
+ // improve locality. This is only possible with elementwise operations
+ // (checked in liveness analysis) which are necessarily top-level
+ // array-shaped buffers.
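+ //
+ // For example, for an elementwise c = add(a, b), the buffer defined by the
+ // add may be placed in the allocation already holding a or b, provided
+ // liveness analysis shows no interference.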
+ if (buffer->IsTopLevel() && !buffer->IsTuple()) {
+ for (auto* operand : buffer->instruction()->operands()) {
+ bool assigned_operand = false;
+ for (const auto& operand_allocation :
+ assignment->GetAllocations(operand, /*index=*/{})) {
+ BufferAllocation* allocation =
+ assignment->GetMutableAllocation(operand_allocation.index());
+ if (colocated_buffer_allocations_.find(allocation->index()) ==
+ colocated_buffer_allocations_.end()) {
+ // TODO(b/32491382) Colocated buffers are currently assigned in an
+ // earlier pass, and so can break the "increasing allocation size"
+ // invariant in this function (causing this CHECK to fail). However,
+ // the call to MaybeAssignBuffer is safe as it returns false if
+ // allocation.size < buffer.size.
+ CHECK_GE(allocation->size(), buffer_size_(*buffer));
+ }
+ if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
+ VLOG(3) << "Reusing (operand) allocation for: "
+ << buffer->ToString();
+ assigned_operand = true;
+ break;
+ }
+ }
+ if (assigned_operand) {
+ break;
+ }
+ }
+ }
+
+ if (!assignment->HasAllocation(*buffer)) {
+ // Find the smallest previously created allocation which can be reused,
+ // iterating from the end of allocation_indices (smallest) to the
+ // beginning (largest).
+ for (int allocation_index = allocation_indices.size() - 1;
+ allocation_index >= 0; allocation_index--) {
+ BufferAllocation* allocation = assignment->GetMutableAllocation(
+ allocation_indices[allocation_index]);
+ // Buffers are iterated in decreasing size order, so any previously
+ // created allocation must be large enough to hold this instruction's
+ // output (with the exception of colocated buffers).
+ if (colocated_buffer_allocations_.find(allocation->index()) ==
+ colocated_buffer_allocations_.end()) {
+ // TODO(b/32491382) Colocated buffers are currently assigned in an
+ // earlier pass, and so can break the "increasing allocation size"
+ // invariant in this function (causing this CHECK to fail). However,
+ // the call to MaybeAssignBuffer is safe as it returns false if
+ // allocation.size < buffer.size.
+ CHECK_GE(allocation->size(), buffer_size_(*buffer));
+ }
+
+ if (MaybeAssignBuffer(allocation, *buffer, assignment)) {
+ VLOG(3) << "Reusing buffer for: " << buffer->ToString();
+ break;
+ }
+ }
+ }
+ if (!assignment->HasAllocation(*buffer)) {
+ auto* allocation = assignment->NewAllocation(
+ *buffer, buffer_size_(*buffer), is_thread_local);
+ VLOG(3) << "New allocation for: " << buffer->ToString();
+ allocation_indices.push_back(allocation->index());
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+void BufferAssigner::AddBufferToColocatedBufferSet(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis,
+ BufferAssigner::ColocatedBufferSet* colocated_buffer_set) {
+ const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
+ // CopyInsertion ensures root points-to set is unambiguous and distinct.
+ CHECK(!points_to.IsAmbiguous());
+ CHECK(points_to.IsDistinct());
+ colocated_buffer_set->push_back(points_to.element(index)[0]);
+}
+
+// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
+// in the same allocation (currently just supports kWhile).
+std::vector<BufferAssigner::ColocatedBufferSet>
+BufferAssigner::BuildColocatedBufferSets(
+ const HloModule* module, const TuplePointsToAnalysis& points_to_analysis) {
+ std::vector<ColocatedBufferSet> colocated_buffer_sets;
+ for (auto& computation : module->computations()) {
+ for (auto& instruction : computation->instructions()) {
+ if (instruction->opcode() != HloOpcode::kWhile) {
+ continue;
+ }
+ HloInstruction* while_hlo = instruction.get();
+ TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+ while_hlo->shape(),
+ [this, &points_to_analysis, &while_hlo, &colocated_buffer_sets](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ ColocatedBufferSet colocated_buffer_set;
+ // Add while.init.
+ AddBufferToColocatedBufferSet(while_hlo->operand(0), index,
+ points_to_analysis,
+ &colocated_buffer_set);
+ // Add while.result.
+ AddBufferToColocatedBufferSet(while_hlo, index, points_to_analysis,
+ &colocated_buffer_set);
+ // Add while.cond.parameter.
+ AddBufferToColocatedBufferSet(
+ while_hlo->while_condition()->parameter_instruction(0), index,
+ points_to_analysis, &colocated_buffer_set);
+ // Add while.body.parameter.
+ AddBufferToColocatedBufferSet(
+ while_hlo->while_body()->parameter_instruction(0), index,
+ points_to_analysis, &colocated_buffer_set);
+ // Add while.body.root.
+ AddBufferToColocatedBufferSet(
+ while_hlo->while_body()->root_instruction(), index,
+ points_to_analysis, &colocated_buffer_set);
+
+ colocated_buffer_sets.push_back(std::move(colocated_buffer_set));
+ return tensorflow::Status::OK();
+ }));
+ }
+ }
+ return colocated_buffer_sets;
+}
+
+// Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
+// allocation in 'assignment'.
+void BufferAssigner::AssignColocatedBufferSets(
+ const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+ BufferAssignment* assignment) {
+ for (const auto& colocated_buffer_set : colocated_buffer_sets) {
+ BufferAllocation* allocation = nullptr;
+ for (const auto& buffer : colocated_buffer_set) {
+ if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) {
+ // ColocatedBufferSet duplicates can occur if a buffer is forwarded
+ // from one instruction to another (e.g., while.body param to root).
+ continue;
+ }
+ if (allocation == nullptr) {
+ // TODO(b/32491382) Avoid current trivial solution of using new
+ // allocations for each colocated buffer set. When liveness has
+ // module-level scope, we can allow buffers to be shared across
+ // computations (in some cases).
+ allocation = assignment->NewAllocation(*buffer, buffer_size_(*buffer),
+ /*is_thread_local=*/false);
+ colocated_buffer_allocations_.insert(allocation->index());
+ } else {
+ assignment->AddAssignment(*buffer, allocation);
+ }
+ colocated_buffers_.insert(buffer);
+ }
+ }
+}
+
+StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ const std::vector<const HloInstruction*>* hlos_to_allocate) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
+ BufferLiveness::Run(module, std::move(hlo_ordering)));
+
+ std::vector<const HloComputation*> thread_local_computations;
+ std::vector<const HloComputation*> global_computations;
+ VLOG(1) << "Assigning buffers to module " << module->name();
+ if (hlos_to_allocate != nullptr) {
+ VLOG(3) << "LogicalBuffer assignment restricted to hlos: ";
+ for (auto hlo : *hlos_to_allocate) {
+ VLOG(3) << " " << hlo->parent()->name() << "::" << hlo->name();
+ }
+ }
+ XLA_VLOG_LINES(3, module->ToString());
+ XLA_VLOG_LINES(3, liveness->ToString());
+ XLA_VLOG_LINES(3, liveness->points_to_analysis().ToString());
+
+ TF_RETURN_IF_ERROR(GatherComputationsByAllocationType(
+ module, &thread_local_computations, &global_computations));
+
+ // Set of HLOs to allocate if hlos_to_allocate is given. Passed as a set to
+ // AssignBuffersForComputation for fast membership testing.
+ std::unique_ptr<std::unordered_set<const HloInstruction*>> hlo_set;
+ if (hlos_to_allocate != nullptr) {
+ hlo_set = MakeUnique<std::unordered_set<const HloInstruction*>>(
+ hlos_to_allocate->begin(), hlos_to_allocate->end());
+ }
+
+ // Can't use MakeUnique because BufferAssignment constructor is private.
+ std::unique_ptr<BufferAssignment> assignment(
+ new BufferAssignment(module, std::move(liveness)));
+
+ // Assign buffers with the tightest constraints first (colocated buffer sets).
+ // Once b/32491382 enables module-level liveness analysis, we may be able
+ // to assign colocated buffers (or at least reuse their allocation for
+ // buffers outside of the set) in AssignBuffersForComputation.
+ if (colocate_related_buffers_) {
+ std::vector<ColocatedBufferSet> colocated_buffer_sets =
+ BuildColocatedBufferSets(module, assignment->points_to_analysis());
+ AssignColocatedBufferSets(colocated_buffer_sets, assignment.get());
+ }
+
+ for (auto* computation : global_computations) {
+ TF_RETURN_IF_ERROR(AssignBuffersForComputation(
+ computation,
+ /*is_thread_local=*/false, hlo_set.get(), assignment.get()));
+ }
+ for (auto* computation : thread_local_computations) {
+ TF_RET_CHECK(computation != module->entry_computation());
+ TF_RETURN_IF_ERROR(AssignBuffersForComputation(
+ computation,
+ /*is_thread_local=*/true, hlo_set.get(), assignment.get()));
+ }
+
+ // Mark all buffers which may be live out of the entry computation as
+ // "liveout".
+ auto entry = module->entry_computation();
+ auto root_instruction = entry->root_instruction();
+ const PointsToSet& root_points_to =
+ assignment->GetPointsToSet(root_instruction);
+ TF_RETURN_IF_ERROR(root_points_to.ForEachElement([&assignment](
+ const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& buffers) {
+ for (auto buffer : buffers) {
+ if (assignment->HasAllocation(*buffer)) {
+ assignment->GetMutableAssignedAllocation(*buffer)->set_maybe_live_out(
+ true);
+ }
+ }
+ return tensorflow::Status::OK();
+ }));
+
+ XLA_VLOG_LINES(2, assignment->ToString());
+
+ // Compute sizes of various kinds of buffers for logging.
+ int64 total_size = 0;
+ int64 parameter_size = 0;
+ for (auto& allocation : assignment->Allocations()) {
+ if (allocation.is_entry_computation_parameter()) {
+ parameter_size += allocation.size();
+ }
+ total_size += allocation.size();
+ }
+
+ // Compute the total size of the output. Iterate over the subshapes and sum up
+ // the sizes of the buffers for each subshape.
+ int64 output_size = 0;
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape(
+ root->shape(), [this, &output_size, root, &assignment](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ const auto& allocations = assignment->GetAllocations(root, index);
+ if (allocations.size() > 0) {
+ output_size += allocations.begin()->size();
+ }
+ return tensorflow::Status::OK();
+ }));
+
+ VLOG(1) << "Allocation sizes for module " << module->name() << ":";
+ VLOG(1) << " parameter allocation total size: "
+ << tensorflow::strings::HumanReadableNumBytes(parameter_size);
+ VLOG(1) << " output allocation total size: "
+ << tensorflow::strings::HumanReadableNumBytes(output_size);
+ VLOG(1) << " temp allocation total size: "
+ << tensorflow::strings::HumanReadableNumBytes(
+ total_size - parameter_size - output_size);
+ VLOG(1) << " total allocation size: "
+ << tensorflow::strings::HumanReadableNumBytes(total_size);
+ return std::move(assignment);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
new file mode 100644
index 0000000000..af455de298
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -0,0 +1,358 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
+
+#include <functional>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// This class abstracts an allocation of contiguous memory which can hold the
+// values described by LogicalBuffers. A BufferAllocation may hold different
+// LogicalBuffers at different times, but currently never more than one
+// LogicalBuffer simultaneously. The abstraction includes information required
+// by the backends for allocation, use, and deallocation of the buffer. This
+// includes the LogicalBuffers which are held in this allocation through the
+// execution of the computation.
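+// For example, an allocation created for an entry computation parameter
+// records the parameter number (see set_entry_computation_parameter) so that a
+// backend can bind the caller-provided argument buffer to it.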
+class BufferAllocation {
+ public:
+ // Holds a unique identifier for each allocation. Values are assigned
+ // contiguously and can be used as array indexes.
+ using Index = int64;
+
+ BufferAllocation(Index index, int64 size, bool is_thread_local)
+ : index_(index), size_(size), is_thread_local_(is_thread_local) {}
+ ~BufferAllocation() {}
+
+ // Adds a LogicalBuffer to the set assigned to this buffer.
+ void AddAssignment(const LogicalBuffer& buffer);
+
+ // Whether this allocation is used in a parallel calling context such as
+ // inside of a map or reduce computation. Such allocations need to be thread
+ // local.
+ bool is_thread_local() const { return is_thread_local_; }
+
+ // Whether this allocation holds a LogicalBuffer from a parameter of the entry
+ // computation. These buffers have lifetimes which may be longer than the
+ // XLA computation.
+ bool is_entry_computation_parameter() const {
+ return is_entry_computation_parameter_;
+ }
+ // If this allocation holds a Buffer from a parameter of the entry
+ // computation, this method returns the parameter number. CHECKs otherwise.
+ int64 parameter_number() const {
+ CHECK(is_entry_computation_parameter_);
+ return parameter_number_;
+ }
+ // Sets that this allocation holds a LogicalBuffer from a parameter of the
+ // entry computation.
+ void set_entry_computation_parameter(int64 parameter_number) {
+ is_entry_computation_parameter_ = true;
+ parameter_number_ = parameter_number;
+ }
+
+ // Returns/sets whether this allocation is assigned a LogicalBuffer which may
+ // be live out of the entry computation.
+ bool maybe_live_out() const { return maybe_live_out_; }
+ void set_maybe_live_out(bool value) { maybe_live_out_ = value; }
+
+ // Returns the size of the allocation. Necessarily this must be at least as
+ // large as any LogicalBuffer assigned to this allocation.
+ int64 size() const { return size_; }
+
+ // Access to the logical buffers assigned to this allocation.
+ const std::vector<const LogicalBuffer*>& assigned_buffers() const {
+ return assigned_buffers_;
+ }
+
+ Index index() const { return index_; }
+
+ string ToString() const;
+
+ // Whether the buffer is a parameter to or live out of the entry computation.
+ bool IsInputOrOutput() const {
+ return is_entry_computation_parameter() || maybe_live_out();
+ }
+
+ // Whether the buffer is a temporary buffer allocated before
+ // Executable::ExecuteOnStream.
+ bool IsPreallocatedTempBuffer() const {
+ // Parameters do not need temporary buffers.
+ return !is_entry_computation_parameter() &&
+ // LogicalBuffers that may be pointed to by the output should live out
+ // of the computation.
+ !maybe_live_out() &&
+ // Thread-local buffers are allocated using `alloca`s.
+ !is_thread_local();
+ }
+
+ bool operator==(const BufferAllocation& other) const {
+ return index_ == other.index_;
+ }
+ bool operator!=(const BufferAllocation& other) const {
+ return !(*this == other);
+ }
+ bool operator<(const BufferAllocation& other) const {
+ return index() < other.index();
+ }
+
+ private:
+ // The index of the allocation in the BufferAssignment.
+ Index index_;
+
+ // Size of the allocation in bytes.
+ int64 size_;
+
+ // Whether this buffer needs to be thread-local.
+ bool is_thread_local_;
+
+ // Whether this allocation holds an entry computation parameter. Entry
+ // computation parameters are special because they have lifetimes which may
+ // outlast the computation.
+ bool is_entry_computation_parameter_ = false;
+
+ // If this allocation holds an entry computation parameter, this field
+ // indicates the index (starting from 0) of the parameter.
+ int64 parameter_number_ = 0;
+
+ // Whether the allocation contains a LogicalBuffer which may be live-out of
+ // the entry computation. Note that this flag is conservatively computed by
+ // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_`
+ // might not actually escape.
+ bool maybe_live_out_ = false;
+
+ // The set of buffers assigned to this allocation.
+ std::vector<const LogicalBuffer*> assigned_buffers_;
+};
+
+// Add stream operator for nicer output of CHECK/RET_CHECK failures.
+std::ostream& operator<<(std::ostream& out, const BufferAllocation& s);
+
+// This class encapsulates an assignment of the LogicalBuffers in an XLA
+// module to a set of BufferAllocations.
+class BufferAssignment {
+ public:
+ // Returns the vector containing all buffer allocations in this assignment.
+ const std::vector<BufferAllocation>& Allocations() const {
+ return allocations_;
+ }
+
+ // Returns whether the given buffer has been assigned an allocation.
+ bool HasAllocation(const LogicalBuffer& buffer) const;
+
+ // Returns the allocation that a particular LogicalBuffer has been assigned
+ // to. CHECKs if buffer has not been assigned an allocation.
+ const BufferAllocation& GetAssignedAllocation(
+ const LogicalBuffer& buffer) const;
+
+ // Returns the allocation with the given index. CHECKs if no allocation exists
+ // with the given index.
+ const BufferAllocation& GetAllocation(BufferAllocation::Index index) const;
+
+ // Builds and returns a vector containing the allocations which might contain
+ // the subvalue at the given index of given instruction.
+ std::set<BufferAllocation> GetAllocations(const HloInstruction* instruction,
+ const ShapeIndex& index) const;
+
+ // Convenience function which returns whether the top-level buffer of the
+ // instruction (index == {}) is assigned an allocation.
+ bool HasTopLevelAllocation(const HloInstruction* instruction) const;
+
+ // Convenience function which returns the unique buffer allocation containing
+ // the buffer at the given index of the given instruction. If an allocation is
+ // not assigned or the allocation cannot be determined at compile time then an
+ // error is returned.
+ StatusOr<const BufferAllocation*> GetUniqueAllocation(
+ const HloInstruction* instruction, const ShapeIndex& index) const;
+ // Like GetUniqueAllocation but fixes the index to the top-level of the shape
+ // (index = {}).
+ StatusOr<const BufferAllocation*> GetUniqueTopLevelAllocation(
+ const HloInstruction* instruction) const;
+ // Like GetUniqueTopLevelAllocation but returns the allocation for the output
+ // of the entry computation of the HLO module (ie, the result of the XLA
+ // computation).
+ StatusOr<const BufferAllocation*> GetUniqueTopLevelOutputAllocation() const;
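+ //
+ // Illustrative usage of the accessor above (a sketch only; `assignment` is
+ // assumed to be a pointer to a completed BufferAssignment):
+ //
+ //   TF_ASSIGN_OR_RETURN(const BufferAllocation* output_allocation,
+ //                       assignment->GetUniqueTopLevelOutputAllocation());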
+
+ // Returns the set of LogicalBuffers which may be the source of the value at
+ // the given index of the given instruction.
+ const std::vector<const LogicalBuffer*>& GetSourceBuffers(
+ const HloInstruction* instruction, const ShapeIndex& index) const {
+ return GetPointsToSet(instruction).element(index);
+ }
+
+ // Returns the underlying points-to analysis used for this assignment.
+ const TuplePointsToAnalysis& points_to_analysis() const {
+ return liveness_->points_to_analysis();
+ }
+
+ string ToString() const;
+
+ private:
+ // Only BufferAssigner can build or modify BufferAssignments.
+ friend class BufferAssigner;
+
+ explicit BufferAssignment(const HloModule* module,
+ std::unique_ptr<BufferLiveness> liveness)
+ : module_(module), liveness_(std::move(liveness)) {}
+
+ // Creates and returns a new BufferAllocation. Ownership is maintained
+ // internally. The allocation initially has only the given LogicalBuffer
+ // assigned to it. `is_thread_local` indicates whether this buffer needs to be
+ // thread-local.
+ BufferAllocation* NewAllocation(const LogicalBuffer& buffer, int64 size,
+ bool is_thread_local);
+
+ // Adds a LogicalBuffer to the set assigned to the given allocation.
+ void AddAssignment(const LogicalBuffer& buffer, BufferAllocation* allocation);
+
+ // Returns the BufferLiveness object used to construct this assignment.
+ const BufferLiveness& liveness() { return *liveness_; }
+
+ // Convenience function which returns the PointsToSet for the given
+ // instruction. Extracted from the liveness object.
+ const PointsToSet& GetPointsToSet(const HloInstruction* instruction) const;
+
+ // Mutable accessors for allocations.
+ BufferAllocation* GetMutableAssignedAllocation(const LogicalBuffer& buffer);
+ BufferAllocation* GetMutableAllocation(BufferAllocation::Index index);
+
+ // The vector of buffer allocations. Indexed by BufferAllocation::Index.
+ std::vector<BufferAllocation> allocations_;
+
+ // Maps Buffers to the index of the BufferAllocation which holds the buffer.
+ std::map<const LogicalBuffer*, BufferAllocation::Index>
+ allocation_index_for_buffer_;
+
+ const HloModule* module_;
+ std::unique_ptr<BufferLiveness> liveness_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BufferAssignment);
+};
+
+// A class which constructs a buffer assignment.
+class BufferAssigner {
+ public:
+ // Build and return a BufferAssignment for the given module. The given
+ // HloOrdering is used to determine buffer liveness. buffer_size is a function
+ // which returns the size of a LogicalBuffer. If hlos_to_allocate is not null
+ // then only instructions in this vector are considered for buffer
+ // assignment. If hlos_to_allocate is null then all instructions are
+ // considered. If 'colocate_related_buffers' is true, related LogicalBuffers
+ // will be colocated in the same allocation (i.e., buffers for a while result
+ // will share an allocation with buffers related to that same while
+ // instruction: init operand, condition/body parameter, and body result).
+ using BufferSizeFunction = std::function<int64(const LogicalBuffer&)>;
+ static StatusOr<std::unique_ptr<BufferAssignment>> Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ BufferSizeFunction buffer_size, bool colocate_related_buffers,
+ const std::vector<const HloInstruction*>* hlos_to_allocate = nullptr);
+
+ // Overload of Run which uses ShapeUtil::ByteSizeOf to determine buffer size
+ // and assigns buffers to all HLO instructions in the module.
+ static StatusOr<std::unique_ptr<BufferAssignment>> Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ int64 pointer_size);
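+
+ // Illustrative usage sketch only (the choice of DependencyHloOrdering and the
+ // use of TF_ASSIGN_OR_RETURN are assumptions, not requirements of this API):
+ //
+ //   TF_ASSIGN_OR_RETURN(
+ //       std::unique_ptr<BufferAssignment> assignment,
+ //       BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module),
+ //                           /*pointer_size=*/sizeof(void*)));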
+
+ private:
+ explicit BufferAssigner(BufferSizeFunction buffer_size,
+ bool colocate_related_buffers)
+ : buffer_size_(std::move(buffer_size)),
+ colocate_related_buffers_(colocate_related_buffers) {}
+ virtual ~BufferAssigner() = default;
+
+ // Create a buffer assignment.
+ StatusOr<std::unique_ptr<BufferAssignment>> CreateAssignment(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
+ const std::vector<const HloInstruction*>* hlos_to_allocate = nullptr);
+
+ // Assigns buffers to the instructions in the given computation. "assignment"
+ // is modified to reflect the new buffer assignments. If is_thread_local is
+ // true, then all assigned buffers have the is_thread_local flag set to
+ // true. If hlos_to_allocate is not null it indicates which HLOs to include in
+ // buffer assignment. If null, all instructions in the computation are
+ // included.
+ tensorflow::Status AssignBuffersForComputation(
+ const HloComputation* computation, bool is_thread_local,
+ const std::unordered_set<const HloInstruction*>* hlos_to_allocate,
+ BufferAssignment* assignment);
+
+ // Tries to assign the given buffer to the given allocation. Returns whether
+ // the assignment was successful.
+ bool MaybeAssignBuffer(BufferAllocation* allocation,
+ const LogicalBuffer& buffer,
+ BufferAssignment* assignment);
+
+ using ColocatedBufferSet = std::vector<const LogicalBuffer*>;
+
+ // Returns a vector of ColocatedBufferSet objects, where each
+ // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
+ // which should be colocated in the same buffer allocation.
+ std::vector<ColocatedBufferSet> BuildColocatedBufferSets(
+ const HloModule* module, const TuplePointsToAnalysis& points_to_analysis);
+
+ // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
+ // same set to the same buffer allocation in 'assignment'.
+ void AssignColocatedBufferSets(
+ const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+ BufferAssignment* assignment);
+
+ // Checks that the points-to set of 'instruction' is unambiguous and distinct
+ // (ensured by CopyInsertion), then adds the buffer from the points-to set at
+ // 'index' to 'colocated_buffer_set'.
+ void AddBufferToColocatedBufferSet(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ const TuplePointsToAnalysis& points_to_analysis,
+ BufferAssigner::ColocatedBufferSet* colocated_buffer_set);
+
+ const HloModule* module_;
+
+ // Function which returns the buffer size for a given shape.
+ BufferSizeFunction buffer_size_;
+
+ // Indicates whether related buffers should share the same buffer allocation.
+ const bool colocate_related_buffers_;
+
+ // Set of colocated buffers populated in AssignColocatedBufferSets.
+ std::unordered_set<const LogicalBuffer*> colocated_buffers_;
+
+ // Set of allocations containing colocated buffers.
+ std::unordered_set<BufferAllocation::Index> colocated_buffer_allocations_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
new file mode 100644
index 0000000000..56138a7ee6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -0,0 +1,1051 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+namespace {
+
+// DFS visitor that collects the instructions referenced by a computation
+// without descending into nested computations, i.e., only from the operands.
+class InstructionListVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit InstructionListVisitor(const HloInstruction* root) : root_(root) {}
+
+ Status DefaultAction(HloInstruction* hlo) override {
+ // For each instruction, just push it on the list after walking the
+ // operands.
+ instructions_.push_back(hlo);
+ VLOG(0) << "List instruction " << hlo->ToString();
+ return Status::OK();
+ }
+
+ std::vector<const HloInstruction*> GetInstructions() { return instructions_; }
+
+ private:
+ // The instruction root of the computation.
+ const HloInstruction* root_;
+
+ // The full set of instructions found (may contain duplicates, e.g., kParameter).
+ std::vector<const HloInstruction*> instructions_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InstructionListVisitor);
+};
+
+const std::vector<const HloInstruction*> GetInstructions(HloInstruction* root) {
+ InstructionListVisitor main_list(root);
+ TF_CHECK_OK(root->Accept(&main_list));
+ return main_list.GetInstructions();
+}
+
+class BufferAssignmentTest : public HloTestBase {
+ protected:
+ BufferAssignmentTest() : computation_tracker_() {}
+ ~BufferAssignmentTest() override {}
+
+ // Builds an x+1.0 computation to use in a Map.
+ std::unique_ptr<HloComputation> BuildMapComputationPlus1(const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ auto value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value));
+ return builder.Build();
+ }
+
+ // Builds a simple compare-to-limit (x < 4) computation for a While.
+ //
+ // condition:
+ // const4[s32] -----------------------------------\
+ // \
+ // param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
+ //
+ std::unique_ptr<HloComputation> BuildWhileConditionComputation(
+ const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto const4 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
+ auto index = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kLt, index, const4));
+ return builder.Build();
+ }
+
+ // Builds a simple body computation for a While.
+ //
+ // body:
+ // constv[f32[4]] --------------------------------------\
+ // \
+ // /--- get-tuple-elementv[1] --- addv ---\
+ // param[(s32,f32[4])] ---| tuple
+ // \--- get-tuple-elementc[0] --- addc ---/
+ // /
+ // const1[s32] -----------------------------------------/
+ //
+ std::unique_ptr<HloComputation> BuildWhileBodyComputation(
+ const string& name) {
+ auto builder = HloComputation::Builder(name);
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
+ auto constv = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, t_s32_f32v4_, "x"));
+ auto indexc = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
+ auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
+ indexc->shape(), HloOpcode::kAdd, indexc, const1));
+ auto indexv = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
+ auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
+ constv->shape(), HloOpcode::kAdd, indexv, constv));
+ builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
+ return builder.Build();
+ }
+
+ // Verifies that the given instruction hlo has a valid input buffer assigned,
+ // i.e., the parameter number matches the op's.
+ const BufferAllocation& GetAssignedInputAllocation(
+ const BufferAssignment& buffers, HloInstruction* hlo) {
+ LOG(INFO) << "Checking input: " << hlo->ToString();
+ const BufferAllocation& buffer =
+ *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie();
+ EXPECT_EQ(hlo->parameter_number(), buffer.parameter_number());
+ return buffer;
+ }
+
+ // Verifies that the given instruction hlo has a valid output buffer
+ // assigned, and returns it.
+ const BufferAllocation& GetAssignedOutputAllocation(
+ const BufferAssignment& buffers, HloInstruction* hlo) {
+ LOG(INFO) << "Checking output: " << hlo->ToString();
+ const BufferAllocation& buffer = GetTopLevelAllocation(buffers, hlo);
+ return buffer;
+ }
+
+ // Returns the allocation for the given instruction.
+ const BufferAllocation& GetAllocation(const BufferAssignment& buffers,
+ const HloInstruction* hlo,
+ const ShapeIndex& index) {
+ return *buffers.GetUniqueAllocation(hlo, index).ConsumeValueOrDie();
+ }
+ const BufferAllocation& GetTopLevelAllocation(const BufferAssignment& buffers,
+ const HloInstruction* hlo) {
+ return *buffers.GetUniqueTopLevelAllocation(hlo).ConsumeValueOrDie();
+ }
+
+ // Verifies that every instruction in the given instruction list either has
+ // an assigned top-level buffer or is a kConstant/kParameter, and returns the
+ // total size of all allocations in the assignment.
+ int64 ValidateBuffers(const std::vector<const HloInstruction*>& instructions,
+ const BufferAssignment& buffers) {
+ // Verifies all instructions have buffers (or are constants/parameters).
+ for (const HloInstruction* hlo : instructions) {
+ if (!buffers.HasTopLevelAllocation(hlo)) {
+ // If `hlo` has no assigned buffer, it is either a constant or a nested
+ // parameter.
+ EXPECT_TRUE(HloOpcode::kConstant == hlo->opcode() ||
+ HloOpcode::kParameter == hlo->opcode());
+ continue;
+ }
+ }
+
+ // Gets the total size of all buffers assigned.
+ int64 total_size = 0;
+ for (auto& allocation : buffers.Allocations()) {
+ total_size += allocation.size();
+ }
+ return total_size;
+ }
+
+ // Returns true if the buffers assigned to instructions in "a" are distinct
+ // from the buffers assigned to those in "b" (ie, intersection is empty).
+ bool BuffersDistinct(const std::vector<const HloInstruction*>& a,
+ const std::vector<const HloInstruction*>& b,
+ const BufferAssignment& assignment) {
+ std::set<BufferAllocation::Index> a_buffers;
+ for (const HloInstruction* instruction : a) {
+ if (assignment.HasTopLevelAllocation(instruction)) {
+ a_buffers.insert(assignment.GetUniqueTopLevelAllocation(instruction)
+ .ConsumeValueOrDie()
+ ->index());
+ }
+ }
+
+ for (const HloInstruction* instruction : b) {
+ if (assignment.HasTopLevelAllocation(instruction)) {
+ if (a_buffers.count(assignment.GetUniqueTopLevelAllocation(instruction)
+ .ConsumeValueOrDie()
+ ->index())) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ // Computation tracker for nested computations.
+ ComputationTracker computation_tracker_;
+
+ // Shapes for use in the examples.
+ Shape s32_ = ShapeUtil::MakeShape(xla::S32, {});
+ Shape r0f32_ = ShapeUtil::MakeShape(xla::F32, {});
+ Shape f32vec4_ = ShapeUtil::MakeShape(F32, {4});
+ Shape f32vec10_ = ShapeUtil::MakeShape(F32, {10});
+ Shape f32vec100_ = ShapeUtil::MakeShape(F32, {100});
+ Shape f32a100x10_ = ShapeUtil::MakeShape(F32, {100, 10});
+ Shape t_s32_f32v4_ = ShapeUtil::MakeTupleShape({s32_, f32vec4_});
+ Shape t_s32_f32v10_ = ShapeUtil::MakeTupleShape({s32_, f32vec10_});
+};
+
+namespace {
+std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module) {
+ return BufferAssigner::Run(module, MakeUnique<DependencyHloOrdering>(module),
+ /*pointer_size=*/sizeof(void*))
+ .ConsumeValueOrDie();
+}
+} // namespace
+
+// Tests a computation consisting of a single scalar constant node.
+TEST_F(BufferAssignmentTest, ScalarConstant) {
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+ // Check that the constant does not have a buffer assigned.
+ EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
+}
+
+TEST_F(BufferAssignmentTest, BufferForConst) {
+ // Addition of two vector constants: checks that internal constant nodes have
+ // no buffers assigned, and their consumer has a buffer.
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4.1f, 4.2f, 4.3f, 4.4f})));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+ // The two constant nodes have no buffers assigned.
+ EXPECT_FALSE(buffers->HasTopLevelAllocation(const0));
+ EXPECT_FALSE(buffers->HasTopLevelAllocation(const1));
+ // The add node has an output buffer.
+ GetAssignedOutputAllocation(*buffers, add);
+}
+
+TEST_F(BufferAssignmentTest, BufferForOutputConst) {
+ // This computation copies a constant to output.
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+ // The copy node now has an output buffer.
+ GetAssignedOutputAllocation(*buffers, copy);
+}
+
+TEST_F(BufferAssignmentTest, Basic) {
+ // paramscalar ------- (mul) -- (add) -- (sub)
+ // / / /
+ // param0[100] -------/ / /
+ // / /
+ // param1[100] --------------/--------/
+ auto builder = HloComputation::Builder(TestName());
+ auto paramscalar =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32vec100_, ""));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32vec100_, ""));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
+ auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32vec100_, HloOpcode::kSubtract, add, param1));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+
+ // Distinct input buffers were assigned for parameters.
+ BufferAllocation paramscalar_buffer =
+ GetAssignedInputAllocation(*buffers, paramscalar);
+ BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
+ BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
+ EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
+ EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
+ EXPECT_NE(param0_buffer.index(), param1_buffer.index());
+
+ // The mul node has a valid buffer assigned, doesn't share with input.
+ const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
+ EXPECT_NE(mul_buffer.index(), param0_buffer.index());
+
+ // The add node can reuse the mul node's buffer.
+ const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
+ EXPECT_EQ(add_buffer.index(), mul_buffer.index());
+
+ // The sub node has a valid output buffer assigned.
+ GetAssignedOutputAllocation(*buffers, sub);
+}
+
+TEST_F(BufferAssignmentTest, MultipleUsersForNode) {
+ // This is similar to the Basic test, with the difference that (sub) is
+ // another user of (mul)'s result, so (mul)'s buffer cannot be reused for
+ // (add)'s output.
+ //
+ // paramscalar -------\ /-----------\
+ // \ / \
+ // param0[100] ------- (mul) -- (add) -- (sub)
+ // /
+ // param1[100] ----------------/
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto paramscalar =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, ""));
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32vec100_, ""));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, f32vec100_, ""));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32vec100_, HloOpcode::kMultiply, paramscalar, param0));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32vec100_, HloOpcode::kAdd, mul, param1));
+ auto sub = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32vec100_, HloOpcode::kSubtract, add, mul));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+
+ // Input buffers were assigned for parameters.
+ BufferAllocation paramscalar_buffer =
+ GetAssignedInputAllocation(*buffers, paramscalar);
+ BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
+ BufferAllocation param1_buffer = GetAssignedInputAllocation(*buffers, param1);
+ EXPECT_NE(paramscalar_buffer.index(), param0_buffer.index());
+ EXPECT_NE(paramscalar_buffer.index(), param1_buffer.index());
+ EXPECT_NE(param0_buffer.index(), param1_buffer.index());
+
+ // The mul node had a buffer allocated.
+ const BufferAllocation& mul_buffer = GetTopLevelAllocation(*buffers, mul);
+
+ // Now the add node can't reuse the mul node's buffer.
+ const BufferAllocation& add_buffer = GetTopLevelAllocation(*buffers, add);
+ EXPECT_NE(add_buffer.index(), mul_buffer.index());
+
+ // Log size information for inspection.
+ const std::vector<const HloInstruction*> level0 = GetInstructions(sub);
+ int64 size0 = ValidateBuffers(level0, *buffers);
+ LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
+ << " for " << level0.size() << " instructions; "
+ << "total buffer size " << size0;
+}
+
+TEST_F(BufferAssignmentTest, TrivialMap) {
+ // This tests a trivial x+1 map as the only operation.
+ //
+ // param0[100x10] ---> (map x+1)
+ //
+ // Builds the map function.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto map_computation =
+ module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
+ auto inner_last = map_computation->root_instruction();
+
+ // Creates the main kernel and verifies instruction counts.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(f32a100x10_, {param0}, map_computation));
+ const std::vector<const HloInstruction*> level0 = GetInstructions(map);
+ EXPECT_EQ(2, level0.size()) << "Invalid main kernel size";
+ const std::vector<const HloInstruction*> level1 = GetInstructions(inner_last);
+ EXPECT_EQ(3, level1.size()) << "Invalid nested add+1 size";
+
+ module->AddEntryComputation(builder.Build());
+
+ // Assigns buffers and fetches sizes.
+ auto buffers = RunBufferAssignment(module.get());
+ int64 size0 = ValidateBuffers(level0, *buffers);
+ int64 size1 = ValidateBuffers(level1, *buffers);
+
+ // The assignment algorithm processes the main kernel before the embedded
+ // computation, so we can verify that the buffers aren't shared between them
+ // by checking:
+ EXPECT_TRUE(BuffersDistinct(level0, level1, *buffers))
+ << "Reuse between main kernel and embedded mapping.";
+
+ // An input buffer was assigned for the parameter.
+ BufferAllocation param0_buffer = GetAssignedInputAllocation(*buffers, param0);
+
+ // An output buffer was assigned for the map.
+ BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map);
+ EXPECT_NE(param0_buffer.index(), map_buffer.index());
+
+ // The final computation node of the map is an add of an f32 parm and a
+ // constant.
+ EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode());
+ const BufferAllocation& inner_add_buffer =
+ GetTopLevelAllocation(*buffers, inner_last);
+ EXPECT_NE(inner_add_buffer.index(), map_buffer.index());
+
+ // Log size information for inspection.
+ LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
+ << " for " << level0.size() + level1.size() << " instructions; "
+ << "total buffer size " << size0 + size1;
+}
+
+TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) {
+  // Make sure that the input buffer of a reduce cannot be reused for its
+  // output. (Reuse is not safe in the general case: the reduce reshapes its
+  // input, and an out-of-order reduction could overwrite an element before it
+  // is read.)
+ //
+  // param0[100x10] --- (exp1) --- (exp2) --- (reduce x+1) --- (exp3)
+ auto module = MakeUnique<HloModule>(TestName());
+ auto reduce_computation =
+ module->AddEmbeddedComputation(BuildMapComputationPlus1("f32+1"));
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32a100x10_, ""));
+ auto exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, param0));
+ auto exp2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1));
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
+ /*shape=*/f32vec10_,
+ /*operand=*/exp2,
+ /*init_value=*/const0,
+ /*dimensions_to_reduce=*/{0}, reduce_computation));
+ auto exp3 = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec10_, HloOpcode::kExp, reduce));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto buffers = RunBufferAssignment(module.get());
+ const std::vector<const HloInstruction*> instrs = GetInstructions(exp3);
+ ValidateBuffers(instrs, *buffers);
+
+ const BufferAllocation& exp1_buffer = GetTopLevelAllocation(*buffers, exp1);
+ const BufferAllocation& exp2_buffer = GetTopLevelAllocation(*buffers, exp2);
+ const BufferAllocation& reduce_buffer =
+ GetTopLevelAllocation(*buffers, reduce);
+
+  // The buffer of exp1 is trivially reusable for exp2; this is just a sanity
+  // check.
+ EXPECT_EQ(exp1_buffer.index(), exp2_buffer.index());
+
+ // The buffer of exp2 cannot be used for reduce, even though it's the only
+ // operand.
+ EXPECT_NE(exp2_buffer.index(), reduce_buffer.index());
+}
+
+TEST_F(BufferAssignmentTest, ExampleWhile) {
+ // This tests a While loop example from the ir_semantics document.
+ //
+ // condition (s32,f32[4]) -> bool -- see BuildWhileConditionComputation.
+ // body: (s32,f32[4]) -> (s32,f32[4]) -- see BuildWhileBodyComputation.
+ //
+ // const3[s32] -------\
+ // const4[f32[4]] --- tuple --- while[condition, body]
+ //
+ // Builds the nested condition and body.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto condition_computation =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("if<4"));
+ auto body_computation =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("add-update"));
+
+ // Creates the main kernel and verifies instruction counts.
+ auto builder = HloComputation::Builder(TestName());
+ auto const3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(0)));
+ auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 4.4f})));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({const3, const4}));
+ auto while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+ t_s32_f32v4_, condition_computation, body_computation, tuple));
+
+ const std::vector<const HloInstruction*> level0 = GetInstructions(while_op);
+ EXPECT_EQ(4, level0.size()) << "Invalid while kernel size";
+ const std::vector<const HloInstruction*> levelc =
+ GetInstructions(condition_computation->root_instruction());
+ EXPECT_EQ(4, levelc.size()) << "Invalid nested condition size";
+ const std::vector<const HloInstruction*> levelb =
+ GetInstructions(body_computation->root_instruction());
+ EXPECT_EQ(8, levelb.size()) << "Invalid nested body size";
+
+ module->AddEntryComputation(builder.Build());
+
+ // Assigns buffers and fetches sizes.
+ auto buffers = RunBufferAssignment(module.get());
+ int64 size0 = ValidateBuffers(level0, *buffers);
+ int64 sizec = ValidateBuffers(levelc, *buffers);
+ int64 sizeb = ValidateBuffers(levelb, *buffers);
+
+ // BufferAssignment will assign a single allocation for the following
+ // instructions: while, while.cond.param, while.body.param, while.body.result.
+ EXPECT_FALSE(BuffersDistinct(level0, levelc, *buffers))
+ << "Should be reuse between main kernel and embedded condition.";
+ EXPECT_FALSE(BuffersDistinct(levelb, levelc, *buffers))
+ << "Should be reuse between embedded condition and body.";
+ // Expect buffer reuse between main kernel and body computation.
+ EXPECT_FALSE(BuffersDistinct(level0, levelb, *buffers))
+ << "Should be reuse between main kernel and embedded body.";
+
+ // The final computation node of the while body is a tuple of s32 and
+ // f32[4] adds.
+ HloInstruction* body_root = body_computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kTuple, body_root->opcode());
+
+ // Check that buffer for each subshape of 'while_op' shares allocation with
+ // corresponding buffer from while body computation at same index.
+ TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+ while_op->shape(),
+ [this, &buffers, while_op, body_root](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
+ auto while_op_allocation = GetAllocation(*buffers, while_op, index);
+ auto body_root_allocation = GetAllocation(*buffers, body_root, index);
+ EXPECT_EQ(while_op_allocation.index(), body_root_allocation.index());
+ return Status::OK();
+ }));
+
+ // Log size information for inspection.
+ LOG(INFO) << "LogicalBuffer count " << buffers->Allocations().size()
+ << " for " << level0.size() + levelc.size() + levelb.size()
+ << " instructions; total buffer size " << size0 + sizec + sizeb;
+}
+
+TEST_F(BufferAssignmentTest, UnaryOpReuseChain) {
+ // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg)
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, ""));
+ auto exp1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, param0));
+ auto tanh = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kTanh, exp1));
+ auto exp2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kExp, tanh));
+ auto neg = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, exp2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+  // tanh, exp2, and neg can all reuse exp1's buffer.
+ EXPECT_TRUE(assignment->HasTopLevelAllocation(exp1));
+ auto& buffer_for_exp1 = GetTopLevelAllocation(*assignment, exp1);
+ EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, tanh));
+ EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, exp2));
+ EXPECT_EQ(buffer_for_exp1, GetTopLevelAllocation(*assignment, neg));
+}
+
+TEST_F(BufferAssignmentTest, ReuseNonOperandBuffer) {
+ // This computation is a chain of operations which decreases in buffer size
+ // (via slice) then increases in size (via broadcast):
+ //
+ // param ---> (negate) ---> (slice) ---> (broadcast)
+ //
+ // The negate should share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // negate and broadcast should share a buffer.
+ EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
+ auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
+ EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
+
+ // Slice should have its own buffer.
+ EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, NoReuseLiveBuffer) {
+ // This computation is identical to that in ReuseNonOperandBuffer, but the
+ // negate value is live until the end of the computation (due to it being an
+ // operand of the output tuple) preventing reuse.
+ //
+ // param ---> (negate) ---> (slice) ---> (broadcast)-> (tuple)
+ // \-----------------------------------/
+ //
+ // The negate should not share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
+ builder.AddInstruction(HloInstruction::CreateTuple({negate, broadcast}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // The instructions should not share buffers.
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, negate));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, slice));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
+ GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, NoReuseAliasedBuffer) {
+ // This computation is identical to that in ReuseNonOperandBuffer, but the
+ // negate value is placed into a tuple which lives to the end of the
+ // computation. This extends the live range of negate's buffer preventing
+ // reuse due to buffer aliasing.
+ //
+ // param ---> (negate) ---> (tuple) -> (slice) ---> (broadcast)-> (tuple)
+ // \-----------------------------------/
+ //
+ // The negate should not share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({negate}));
+ auto tuple_element = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(f32vec100_, tuple, 0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, tuple_element, {0}, {10}));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32a100x10_, slice, {1}));
+ builder.AddInstruction(HloInstruction::CreateTuple({tuple, broadcast}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // The instructions should not share buffers.
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, negate));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, slice));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
+ GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBuffer) {
+ // This computation is very similar to ReuseNonOperandBuffer except the
+ // broadcast has a smaller output than the negate. This should block reuse of
+ // negate's buffer by broadcast because the output buffer(s) of a computation
+ // should be exactly sized for the value.
+ //
+ // param ---> (negate) ---> (slice) ---> (broadcast)
+ //
+ // The negate should *not* share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ // Negate output is 100 elements.
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ // Broadcast output is 40 elements.
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // The instructions should not share buffers.
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, negate));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, slice));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
+ GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, ReuseOutputBufferIfExactlySized) {
+ // This is identical to DoNotReuseOversizedOutputBuffer except the broadcast
+ // output is exactly the same size as the negate (rather than being
+ // smaller). This enables reuse of negate's buffer by the broadcast because
+ // the output buffer will be sized exactly to its value.
+ //
+ // param ---> (negate) ---> (slice) ---> (broadcast)
+ //
+  // The negate *should* share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ // Negate output is 100 elements.
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+  // Broadcast output is 100 elements, the same size as the negate output.
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {10, 10}), slice, {0}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // negate and broadcast should share a buffer.
+ EXPECT_TRUE(assignment->HasTopLevelAllocation(broadcast));
+ auto& buffer_for_bcast = GetTopLevelAllocation(*assignment, broadcast);
+ EXPECT_EQ(buffer_for_bcast, GetTopLevelAllocation(*assignment, negate));
+
+ // Slice should have its own buffer.
+ EXPECT_NE(buffer_for_bcast, GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, DoNotReuseOversizedOutputBufferInTuple) {
+ // This computation is very similar to ReuseNonOperandBuffer except the
+ // broadcast has a smaller output than the negate, and the broadcast is
+ // contained in the computation output as a tuple element. This should block
+ // reuse of the negate's buffer by the broadcast because the output buffer(s)
+ // of a computation should be exactly sized for the value. This includes those
+ // buffers aliased in the output (eg, contained as tuple elements).
+ //
+ // param ---> (negate) ---> (slice) ---> (broadcast) --> (tuple)
+ //
+ // The negate should *not* share a buffer with broadcast.
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32vec100_, "param0"));
+ // Negate output is 100 elements.
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32vec100_, HloOpcode::kNegate, param0));
+ auto slice = builder.AddInstruction(
+ HloInstruction::CreateSlice(f32vec10_, negate, {0}, {10}));
+ // Broadcast output is 40 elements.
+ auto broadcast = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {10, 4}), slice, {0}));
+ builder.AddInstruction(HloInstruction::CreateTuple({broadcast}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // The instructions should not share buffers.
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, negate));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, broadcast),
+ GetTopLevelAllocation(*assignment, slice));
+ EXPECT_NE(GetTopLevelAllocation(*assignment, negate),
+ GetTopLevelAllocation(*assignment, slice));
+}
+
+TEST_F(BufferAssignmentTest, EmbeddedComputationBuffers) {
+ // Verify that buffers for embedded computations are properly marked as
+ // thread-local and that embedded parameters are not marked as
+ // is_entry_computation_parameter.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto vec_shape = ShapeUtil::MakeShape(F32, {42});
+ auto scalar_shape = ShapeUtil::MakeShape(F32, {});
+
+ // Create a scalar computation to use in a map.
+ auto map_builder = HloComputation::Builder(TestName() + "_map");
+ auto map_param = map_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "map_param"));
+ auto map_root = map_builder.AddInstruction(
+ HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, map_param));
+ auto map_computation = module->AddEmbeddedComputation(map_builder.Build());
+
+ // Create a vector computation to use in a kCall.
+ auto call_builder = HloComputation::Builder(TestName() + "_call");
+ auto call_param = call_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vec_shape, "vec_param"));
+ auto call_root = call_builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_shape, HloOpcode::kExp, call_param));
+ auto call_computation = module->AddEmbeddedComputation(call_builder.Build());
+
+ // Create entry computation which kCalls call_computation and then calls map
+ // with map_computation on the result.
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vec_shape, "param"));
+ auto call = builder.AddInstruction(
+ HloInstruction::CreateCall(vec_shape, {param}, call_computation));
+ auto map = builder.AddInstruction(
+ HloInstruction::CreateMap(vec_shape, {call}, map_computation));
+ module->AddEntryComputation(builder.Build());
+
+ auto assignment = RunBufferAssignment(module.get());
+
+ // Allocations for the map computation should be thread-local and not
+ // live-out.
+ auto& map_param_alloc = GetTopLevelAllocation(*assignment, map_param);
+ EXPECT_FALSE(map_param_alloc.is_entry_computation_parameter());
+ EXPECT_FALSE(map_param_alloc.maybe_live_out());
+ EXPECT_TRUE(map_param_alloc.is_thread_local());
+
+ auto& map_root_alloc = GetTopLevelAllocation(*assignment, map_root);
+ EXPECT_FALSE(map_root_alloc.is_entry_computation_parameter());
+ EXPECT_FALSE(map_root_alloc.maybe_live_out());
+ EXPECT_TRUE(map_root_alloc.is_thread_local());
+
+  // Allocations for the call computation should be neither thread-local nor
+  // live-out.
+ auto& call_param_alloc = GetTopLevelAllocation(*assignment, call_param);
+ EXPECT_FALSE(call_param_alloc.is_entry_computation_parameter());
+ EXPECT_FALSE(call_param_alloc.maybe_live_out());
+ EXPECT_FALSE(call_param_alloc.is_thread_local());
+
+ auto& call_root_alloc = GetTopLevelAllocation(*assignment, call_root);
+ EXPECT_FALSE(call_root_alloc.is_entry_computation_parameter());
+ EXPECT_FALSE(call_root_alloc.maybe_live_out());
+ EXPECT_FALSE(call_root_alloc.is_thread_local());
+
+ // Entry computation allocations can be marked liveout and
+ // is_entry_computation_parameter.
+ auto& param_alloc = GetTopLevelAllocation(*assignment, param);
+ EXPECT_TRUE(param_alloc.is_entry_computation_parameter());
+ EXPECT_FALSE(param_alloc.maybe_live_out());
+ EXPECT_FALSE(param_alloc.is_thread_local());
+
+ auto& map_alloc = GetTopLevelAllocation(*assignment, map);
+ EXPECT_FALSE(map_alloc.is_entry_computation_parameter());
+ EXPECT_TRUE(map_alloc.maybe_live_out());
+ EXPECT_FALSE(map_alloc.is_thread_local());
+}
+
+TEST_F(BufferAssignmentTest, TupleParameterAsOutput) {
+ // Test a computation that returns a tuple parameter.
+ auto builder = HloComputation::Builder(TestName());
+ auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
+ ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeShape(S32, {42})}),
+ "param0"));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+  // There should be four allocations: one for the tuple itself (the vector of
+  // pointers), and one for each tuple element.
+ EXPECT_EQ(4, assignment->Allocations().size());
+
+ // Verify each buffer allocation is marked as an entry computation parameter
+ // and is liveout.
+ TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+ tuple_param->shape(),
+ [this, &assignment, tuple_param](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
+ auto allocation = GetAllocation(*assignment, tuple_param, index);
+ EXPECT_TRUE(allocation.is_entry_computation_parameter());
+ EXPECT_EQ(0, allocation.parameter_number());
+ EXPECT_TRUE(allocation.maybe_live_out());
+ return Status::OK();
+ }));
+}
+
+TEST_F(BufferAssignmentTest, ElementOfNestedTupleParameterAsOutput) {
+  // Test a computation which returns a GetTupleElement of a nested tuple
+ // parameter.
+ auto builder = HloComputation::Builder(TestName());
+ auto tuple_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {42}),
+ ShapeUtil::MakeShape(S32, {101})})}),
+ "param0"));
+ auto tuple_element =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(tuple_param->shape(), {1}), tuple_param, 1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // Only some of the elements of the input param are liveout.
+ EXPECT_FALSE(
+ GetAllocation(*assignment, tuple_param, /*index=*/{}).maybe_live_out());
+ // Tuple element at index={1} is live out because GetTupleElement({1})
+ // forwards a pointer to this allocation (instead of defining its own buffer).
+ EXPECT_TRUE(
+ GetAllocation(*assignment, tuple_param, /*index=*/{1}).maybe_live_out());
+ EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0})
+ .maybe_live_out());
+ EXPECT_TRUE(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1})
+ .maybe_live_out());
+
+ // The GetTupleElement output is liveout.
+ EXPECT_TRUE(
+ GetTopLevelAllocation(*assignment, tuple_element).maybe_live_out());
+
+  // Verify that the allocations for the GetTupleElement's elements match the
+  // corresponding tuple parameter allocations because they alias.
+ EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 0}),
+ GetAllocation(*assignment, tuple_element, /*index=*/{0}));
+ EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1, 1}),
+ GetAllocation(*assignment, tuple_element, /*index=*/{1}));
+
+ // GetTupleElement forwards a pointer to its underlying buffer, so verify
+  // that it has the same allocation as the corresponding parameter element.
+ EXPECT_EQ(GetAllocation(*assignment, tuple_param, /*index=*/{1}),
+ GetTopLevelAllocation(*assignment, tuple_element));
+}
+
+// TODO(b/32248867): Enable when buffer assignment gives allocations to
+// constants.
+TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) {
+ // Test that a tuple constant which is forwarded to the computation output is
+ // properly handled.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()})));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ EXPECT_EQ(3, assignment->Allocations().size());
+}
+
+TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
+ // Test a computation which returns a tuple custom call value.
+ auto builder = HloComputation::Builder(TestName());
+ auto custom_call = builder.AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4}),
+ ShapeUtil::MakeShape(S32, {101})}),
+ /*operands=*/{}, /*custom_call_target=*/"foo_function"));
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ EXPECT_EQ(3, assignment->Allocations().size());
+ EXPECT_TRUE(
+ GetAllocation(*assignment, custom_call, /*index=*/{}).maybe_live_out());
+ EXPECT_TRUE(
+ GetAllocation(*assignment, custom_call, /*index=*/{0}).maybe_live_out());
+ EXPECT_TRUE(
+ GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
+}
+
+TEST_F(BufferAssignmentTest, BitcastAsOutput) {
+ // Test a computation which returns a bitcast value.
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {42}), "param"));
+ auto bitcast = builder.AddInstruction(
+ HloInstruction::CreateUnary(param->shape(), HloOpcode::kBitcast, param));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // Bitcast should get the same allocation as the param.
+ EXPECT_EQ(1, assignment->Allocations().size());
+ EXPECT_EQ(GetTopLevelAllocation(*assignment, param),
+ GetTopLevelAllocation(*assignment, bitcast));
+}
+
+TEST_F(BufferAssignmentTest, AmbiguousBufferAsOutput) {
+ // Test a computation with an output that has an ambiguous points-to set. This
+ // is constructed using a select among tuple shapes.
+ auto builder = HloComputation::Builder(TestName());
+ auto tuple_shape =
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(PRED, {1, 2, 3, 4})});
+
+ auto tuple_param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param0"));
+ auto tuple_param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, tuple_shape, "param1"));
+ auto pred_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 2, ShapeUtil::MakeShape(PRED, {}), "param1"));
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred_param, tuple_param0, tuple_param1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ auto assignment = RunBufferAssignment(module.get());
+
+ // Select shallow copies one of its operands so it defines its own top-level
+ // buffer and receives its own allocation.
+ auto select_alloc = GetTopLevelAllocation(*assignment, select);
+ EXPECT_EQ(1, select_alloc.assigned_buffers().size());
+ EXPECT_EQ(select, select_alloc.assigned_buffers()[0]->instruction());
+
+  // The buffer for the tuple element of the select is forwarded from one of
+  // its operands, which cannot be determined statically. Therefore its
+  // allocation should include the allocations of both of the elements in the
+  // parameters.
+ auto element_allocations = assignment->GetAllocations(select, /*index=*/{0});
+ EXPECT_EQ(2, element_allocations.size());
+ EXPECT_MATCH(testing::SetToVec<BufferAllocation>(element_allocations),
+ testing::UnorderedMatcher<BufferAllocation>(
+ *assignment->GetUniqueAllocation(tuple_param0, /*index=*/{0})
+ .ConsumeValueOrDie(),
+ *assignment->GetUniqueAllocation(tuple_param1, /*index=*/{0})
+ .ConsumeValueOrDie()));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
new file mode 100644
index 0000000000..d4faca7cd8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -0,0 +1,259 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Defines buffer liveness analysis and the HLO orderings used by the XLA
+// buffer assignment package.
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module)
+ : module_(module) {}
+
+bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const {
+ // Instructions in different computations are unordered.
+ if (a->parent() != b->parent()) {
+ return false;
+ }
+ // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'.
+ return strict_predecessors_.at(b->parent())->IsReachable(b, a);
+}
+
+string PredecessorHloOrdering::ToStringHelper(const string& name) const {
+ std::vector<string> pieces;
+ pieces.push_back(name);
+ for (auto& computation : module_->computations()) {
+ pieces.push_back(tensorflow::strings::Printf("computation %s:",
+ computation->name().c_str()));
+ const auto all = computation->MakeInstructionPostOrder();
+ for (auto instruction : all) {
+ pieces.push_back(tensorflow::strings::Printf(
+ " %s strict predecessors:", instruction->name().c_str()));
+ for (auto predecessor : all) {
+ if (strict_predecessors_.at(computation.get())
+ ->IsReachable(instruction, predecessor)) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" %s", predecessor->name().c_str()));
+ }
+ }
+ }
+ }
+ return tensorflow::str_util::Join(pieces, "\n");
+}
+
+DependencyHloOrdering::DependencyHloOrdering(const HloModule* module)
+ : PredecessorHloOrdering(module) {
+ // Compute predecessor relationships between all instructions to determine
+ // ordering based on dependencies. ExecutesBefore will return true iff there
+ // exists a path in the HLO computation graph from 'a' to 'b'.
+ for (auto& computation : module->computations()) {
+ strict_predecessors_.emplace(computation.get(),
+ computation->ComputeTransitiveOperands());
+ }
+}
+
+string DependencyHloOrdering::ToString() const {
+ return ToStringHelper("DependencyHloOrdering");
+}
+
+SequentialHloOrdering::SequentialHloOrdering(
+ const HloModule* module, const HloModuleSequence& module_sequence)
+ : module_(module) {
+ // Create a map from instruction to its order position.
+  for (const auto& computation_order : module_sequence) {
+ const std::vector<const HloInstruction*>& order = computation_order.second;
+ for (int i = 0; i < order.size(); ++i) {
+ DCHECK_EQ(0, order_position_.count(order[i]));
+ order_position_.emplace(order[i], i);
+ }
+ }
+}
+
+bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const {
+ // Instructions in different computations are unordered.
+ if (a->parent() != b->parent()) {
+ return false;
+ }
+ // If either instruction is not in the order, then 'a' and 'b' are unordered.
+ if (order_position_.count(a) == 0 || order_position_.count(b) == 0) {
+ return false;
+ }
+ return order_position_.at(a) < order_position_.at(b);
+}
+
+string SequentialHloOrdering::ToString() const {
+ std::vector<string> pieces;
+ pieces.push_back("SequentialHloOrdering");
+ for (auto& computation : module_->computations()) {
+ pieces.push_back(tensorflow::strings::Printf("computation %s order:",
+ computation->name().c_str()));
+ // Gather all instructions in the module sequence for this computation and
+ // sort them by their position.
+ std::vector<const HloInstruction*> instructions;
+ for (auto& instruction_position : order_position_) {
+ const HloInstruction* instruction = instruction_position.first;
+ if (instruction->parent() == computation.get()) {
+ instructions.push_back(instruction);
+ }
+ }
+ std::sort(instructions.begin(), instructions.end(),
+ [this](const HloInstruction* a, const HloInstruction* b) {
+ return order_position_.at(a) < order_position_.at(b);
+ });
+ for (auto instruction : instructions) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" %s", instruction->name().c_str()));
+ }
+ }
+ return tensorflow::str_util::Join(pieces, "\n");
+}
+
+/* static */
+StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering) {
+ std::unique_ptr<BufferLiveness> liveness(
+ new BufferLiveness(module, std::move(hlo_ordering)));
+ TF_RETURN_IF_ERROR(liveness->Analyze());
+ return std::move(liveness);
+}
+
+tensorflow::Status BufferLiveness::Analyze() {
+ TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_));
+ for (auto& computation : module_->computations()) {
+    // Gather into aliased_buffers_ all LogicalBuffers which are aliased in the
+    // output of other instructions, e.g. buffers contained as a tuple element
+    // in another instruction's output.
+ for (const auto& instruction : computation->instructions()) {
+ for (const LogicalBuffer* aliased_buffer :
+ points_to_analysis_->GetPointsToSet(instruction.get())
+ .CreateFlattenedSet()) {
+ if (aliased_buffer->instruction() != instruction.get()) {
+ aliased_buffers_.insert(aliased_buffer);
+ }
+ }
+ }
+
+ if (computation.get() == module_->entry_computation()) {
+ for (const LogicalBuffer* live_out_buffer :
+ points_to_analysis_->GetPointsToSet(computation->root_instruction())
+ .CreateFlattenedSet()) {
+ maybe_live_out_buffers_.insert(live_out_buffer);
+ }
+ }
+ }
+
+ XLA_VLOG_LINES(3, ToString());
+ return tensorflow::Status::OK();
+}
+
+string BufferLiveness::ToString() const {
+ std::vector<string> pieces;
+ pieces.push_back(tensorflow::strings::Printf("BufferLiveness(module=%s):",
+ module_->name().c_str()));
+ pieces.push_back("HloOrdering:");
+ pieces.push_back(hlo_ordering_->ToString());
+ pieces.push_back(tensorflow::strings::Printf("Aliased buffers:"));
+ for (const LogicalBuffer* buffer : aliased_buffers_) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ }
+ pieces.push_back(tensorflow::strings::Printf("Live out buffers:"));
+ for (const LogicalBuffer* buffer : maybe_live_out_buffers_) {
+ pieces.push_back(
+ tensorflow::strings::Printf(" %s", buffer->ToString().c_str()));
+ }
+ return tensorflow::str_util::Join(pieces, "\n");
+}
+
+// Returns false if 'user' cannot possibly use the buffer at 'index' in
+// 'operand'. Returns true otherwise.
+// Precondition: 'operand' is an operand of 'user'.
+bool MayUseBufferInOperand(HloInstruction* operand, const ShapeIndex& index,
+ HloInstruction* user) {
+ if (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()) {
+ // GetTupleElement instructions only access the top-level buffer of their
+ // operand.
+ return false;
+ }
+ return true;
+}
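+
+// For example (illustrative only): if 'operand' is tuple-shaped and 'user' is
+// a kGetTupleElement consuming it, the user only reads the operand's
+// top-level index table, so MayUseBufferInOperand(operand, /*index=*/{0},
+// user) returns false, while MayUseBufferInOperand(operand, /*index=*/{},
+// user) returns true.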
+
+bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
+ const LogicalBuffer& b) const {
+ TF_CHECK_OK(points_to_analysis_->VerifyBuffer(a));
+ TF_CHECK_OK(points_to_analysis_->VerifyBuffer(b));
+
+ if (!hlo_ordering_->ExecutesBefore(a.instruction(), b.instruction())) {
+ return false;
+ }
+
+ // Every user of 'a' must be a predecessor of 'b' or 'b' itself.
+ for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
+ for (auto user : alias.instruction()->users()) {
+ if (!MayUseBufferInOperand(alias.instruction(), alias.index(), user)) {
+ continue;
+ }
+ if (user != b.instruction() &&
+ !hlo_ordering_->ExecutesBefore(user, b.instruction())) {
+ return false;
+ }
+ }
+ }
+
+  // If 'b' is a user of 'a' then the buffers interfere unless 'b' is an
+  // elementwise operation emitting the same shape/layout as 'a'.
+ for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
+ if (alias.instruction()->users().count(b.instruction()) > 0 &&
+ (!ShapeUtil::Equal(alias.instruction()->shape(),
+ b.instruction()->shape()) ||
+ !b.instruction()->IsElementwise())) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool BufferLiveness::MayInterfere(const LogicalBuffer& a,
+ const LogicalBuffer& b) const {
+ return (!live_range_strictly_before(a, b) &&
+ !live_range_strictly_before(b, a));
+}
+
+bool BufferLiveness::MaybeLiveOut(const LogicalBuffer& buffer) const {
+  // Verify that the buffer is actually defined at the given instruction/index
+  // (e.g., it's not an alias of another buffer, as occurs with a bitcast).
+ TF_CHECK_OK(points_to_analysis_->VerifyBuffer(buffer));
+ return maybe_live_out_buffers_.count(&buffer);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
new file mode 100644
index 0000000000..964f558c8c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -0,0 +1,215 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace xla {
+
+// Abstract base class for describing a partial ordering of HLO
+// instructions. Used to determine live range overlap of HLO instruction output
+// buffers.
+class HloOrdering {
+ public:
+ HloOrdering() = default;
+ virtual ~HloOrdering() = default;
+
+ // Returns true if instruction 'a' executes before instruction 'b'. This is
+ // not reflexive, that is, an instruction does not execute before itself.
+ virtual bool ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const = 0;
+ virtual string ToString() const = 0;
+};
+
+// Base class for partial orderings implemented by a map of strict predecessors
+// for each instruction. Subclasses should fill in strict_predecessors_.
+class PredecessorHloOrdering : public HloOrdering {
+ public:
+ ~PredecessorHloOrdering() override = default;
+
+ // Returns true if instruction 'a' executes before instruction 'b'.
+ // Instructions in different computations are not ordered.
+ bool ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const override;
+
+ protected:
+ explicit PredecessorHloOrdering(const HloModule* module);
+ string ToStringHelper(const string& name) const;
+
+ const HloModule* module_;
+
+  // For each computation in the module, this holds each instruction's set of
+  // strict predecessors. An instruction is not an element of its own strict
+  // predecessor set.
+ //
+ // Subclasses should fill this in to define the desired ordering.
+ tensorflow::gtl::FlatMap<const HloComputation*,
+ std::unique_ptr<HloComputation::ReachabilityMap>>
+ strict_predecessors_;
+};
+
+// An HLO ordering based on data dependencies in the HLO graph. In this partial
+// order, instruction A executes before instruction B only if there is a path
+// from A to B in the HLO graph. For example, given the following graph:
+//
+//        param
+//       /     \
+//   negate    exp
+//       \     /
+//         add
+//
+// DependencyHloOrdering gives the following executes-before relations:
+//   param executes before negate, exp, and add
+//   negate executes before add
+//   exp executes before add
+//   add executes before nothing
+// negate and exp are not ordered because the dependencies allow either to
+// execute before the other (or in parallel). DependencyHloOrdering ordering
+// allows maximum parallelism and enables any execution order which satisfies
+// data dependencies. This requires pessimistic assumptions about buffer live
+// ranges and can result in more memory used than more constrained orderings.
+class DependencyHloOrdering : public PredecessorHloOrdering {
+ public:
+ explicit DependencyHloOrdering(const HloModule* module);
+ ~DependencyHloOrdering() override = default;
+
+ string ToString() const override;
+};
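+
+// Illustrative sketch (not part of the interface above): given an HloModule
+// built from the example graph in the preceding comment, the dependency-based
+// order could be queried as follows. The 'module', 'param', 'negate', 'exp',
+// and 'add' variables are assumed to come from an HloComputation::Builder, as
+// in the unit tests.
+//
+//   DependencyHloOrdering ordering(module.get());
+//   ordering.ExecutesBefore(param, add);   // true: a path param -> ... -> add
+//                                          // exists in the graph.
+//   ordering.ExecutesBefore(negate, exp);  // false: negate and exp are
+//                                          // unordered.
+//   ordering.ExecutesBefore(add, add);     // false: the order is not
+//                                          // reflexive.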
+
+// An HLO ordering based on a total order of instructions in each computation.
+// A computation's total order is a sequencing of all of its instructions
+// (e.g., {inst0, inst1, inst2, ...}), as in single-threaded execution. For
+// example, given the following HLO graph:
+//
+//        param
+//       /     \
+//   negate    exp
+//       \     /
+//         add
+//
+// and the following sequence:
+//
+//   {param, negate, exp, add}
+//
+// SequentialHloOrdering gives the following executes-before relations:
+//   param executes before negate, exp, and add
+//   negate executes before exp and add
+//   exp executes before add
+//   add executes before nothing
+// This is more constrained than DependencyHloOrdering in this example because
+// negate and exp are ordered (negate before exp). This enables param to share
+// the same buffer as exp (param buffer is dead after exp). Generally, this
+// ordering enables more buffer sharing (reduced memory usage) because buffer
+// interference is reduced relative to DependencyHloOrdering.
+class SequentialHloOrdering : public HloOrdering {
+ public:
+ // A sequence of instructions for each computation in the module.
+ using HloModuleSequence =
+ tensorflow::gtl::FlatMap<const HloComputation*,
+ std::vector<const HloInstruction*>>;
+
+ SequentialHloOrdering(const HloModule* module,
+ const HloModuleSequence& module_sequence);
+ ~SequentialHloOrdering() override = default;
+
+ // Instruction 'a' executes before 'b' if 'a' appears before 'b' in the
+ // instruction sequence for the computation. Instructions in different
+ // computations are unordered.
+ bool ExecutesBefore(const HloInstruction* a,
+ const HloInstruction* b) const override;
+ string ToString() const override;
+
+ protected:
+ const HloModule* module_;
+
+ // The position of every instruction in the HLO module in its respective
+ // computation sequence (a value of zero indicates the instruction is first in
+ // the sequence, etc). Instructions from all computations are contained in
+ // this map so more than one instruction may have the same position
+ // value. This is not a problem because ExecutesBefore also verifies
+ // instructions are in the same computation.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
+};
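+
+// Illustrative sketch (not part of the interface above): a sequential order
+// for the example sequence {param, negate, exp, add} documented above could
+// be built the same way the unit tests do it. The 'module', 'computation',
+// and instruction variables are assumed to already exist.
+//
+//   SequentialHloOrdering::HloModuleSequence module_sequence;
+//   module_sequence.emplace(
+//       computation,
+//       std::vector<const HloInstruction*>{param, negate, exp, add});
+//   SequentialHloOrdering ordering(module.get(), module_sequence);
+//   ordering.ExecutesBefore(negate, exp);  // true: negate precedes exp in the
+//                                          // sequence.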
+
+// Class which computes liveness of the output buffers of HLOs and their
+// interference.
+class BufferLiveness {
+ public:
+ // Constructs a buffer liveness object for the given module assuming the given
+ // HLO instruction ordering.
+ static StatusOr<std::unique_ptr<BufferLiveness>> Run(
+ const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering);
+
+  // Returns true if the live range of buffer 'a' may overlap with the live
+  // range of buffer 'b'. If the buffers interfere then they cannot share the
+  // same allocation.
+ bool MayInterfere(const LogicalBuffer& a, const LogicalBuffer& b) const;
+
+  // Returns true if the given buffer may be live out of the module. That is,
+  // the buffer may be included in the output of the entry computation.
+ bool MaybeLiveOut(const LogicalBuffer& buffer) const;
+
+ // Returns the underlying points-to analysis used for this liveness analysis.
+ const TuplePointsToAnalysis& points_to_analysis() const {
+ return *points_to_analysis_;
+ }
+
+ string ToString() const;
+
+ private:
+ explicit BufferLiveness(const HloModule* module,
+ std::unique_ptr<HloOrdering> hlo_ordering)
+ : module_(module), hlo_ordering_(std::move(hlo_ordering)) {}
+
+ // Perform buffer liveness analysis. This method must be called prior to
+ // MayInterfere or MaybeLiveOut.
+ tensorflow::Status Analyze();
+
+  // Returns true if the live range of buffer 'a' is strictly before the live
+  // range of buffer 'b' (they do not overlap).
+ bool live_range_strictly_before(const LogicalBuffer& a,
+ const LogicalBuffer& b) const;
+
+ const HloModule* module_;
+ std::unique_ptr<HloOrdering> hlo_ordering_;
+
+ // Set of LogicalBuffers which are aliased in the output of other
+ // instructions. For example, a LogicalBuffer which is inserted into a tuple
+ // is considered to be aliased and will be in this set.
+ tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
+
+ // LogicalBuffers that may be live out of the entry computation.
+ tensorflow::gtl::FlatSet<const LogicalBuffer*> maybe_live_out_buffers_;
+
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+};
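+
+// Illustrative sketch (not part of the interface above): liveness is computed
+// and queried roughly as in the unit tests. The 'module', 'negate', and 'exp'
+// variables are assumed to be an already-built HloModule and two of its
+// array-shaped instructions.
+//
+//   auto liveness =
+//       BufferLiveness::Run(module.get(),
+//                           MakeUnique<DependencyHloOrdering>(module.get()))
+//           .ConsumeValueOrDie();
+//   const LogicalBuffer* negate_buffer =
+//       liveness->points_to_analysis().GetPointsToSet(negate).element({})[0];
+//   const LogicalBuffer* exp_buffer =
+//       liveness->points_to_analysis().GetPointsToSet(exp).element({})[0];
+//   if (!liveness->MayInterfere(*negate_buffer, *exp_buffer)) {
+//     // The two values could share an allocation.
+//   }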
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_LIVENESS_H_
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
new file mode 100644
index 0000000000..1ca5768dbe
--- /dev/null
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -0,0 +1,487 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class BufferLivenessTest : public HloTestBase {
+ protected:
+ // Returns the LogicalBuffer defined at the given instruction and
+  // index. CHECK-fails if no buffer is defined at that point.
+ const LogicalBuffer& GetBuffer(const BufferLiveness& liveness,
+ const HloInstruction* instruction,
+ const ShapeIndex& index) {
+ const std::vector<const LogicalBuffer*>& pointed_to =
+ liveness.points_to_analysis()
+ .GetPointsToSet(instruction)
+ .element(index);
+ CHECK_EQ(1, pointed_to.size());
+ CHECK_EQ(instruction, pointed_to[0]->instruction());
+ CHECK(index == pointed_to[0]->index());
+ return *pointed_to[0];
+ }
+
+ // Returns true if the top-level buffers for instructions 'a' and 'b' may
+ // interfere. Precondition: 'a' and 'b' are array-shaped.
+ bool InstructionsMayInterfere(const BufferLiveness& liveness,
+ HloInstruction* a, HloInstruction* b) {
+ EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
+ EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
+ return liveness.MayInterfere(
+ GetBuffer(liveness, /*instruction=*/a, /*index=*/{}),
+ GetBuffer(liveness, /*instruction=*/b, /*index=*/{}));
+ }
+
+ // Returns true if the tuple elements at 'index' for instructions 'a' and 'b'
+ // may interfere. Precondition: 'a' and 'b' are tuple-shaped, with equal
+ // tuple element sub-shapes.
+ bool TupleElementsMayInterfere(const BufferLiveness& liveness,
+ HloInstruction* a, HloInstruction* b,
+ const ShapeIndex& index) {
+ // Check that top-level shapes are tuple and tuple element shapes are equal.
+ EXPECT_TRUE(ShapeUtil::IsTuple(a->shape()));
+ EXPECT_TRUE(ShapeUtil::IsTuple(b->shape()));
+ EXPECT_TRUE(
+ ShapeUtil::Compatible(ShapeUtil::GetSubshape(a->shape(), index),
+ ShapeUtil::GetSubshape(b->shape(), index)));
+ // Lookup PointsTo set for instructions 'a' and 'b'.
+ auto& points_to_analysis = liveness.points_to_analysis();
+ const std::vector<const LogicalBuffer*>& points_to_a =
+ points_to_analysis.GetPointsToSet(a).element(index);
+ const std::vector<const LogicalBuffer*>& points_to_b =
+ points_to_analysis.GetPointsToSet(b).element(index);
+ // Make sure PointsTo sets for 'a' and 'b' are unambiguous.
+ EXPECT_EQ(1, points_to_a.size());
+ EXPECT_EQ(points_to_a.size(), points_to_b.size());
+ // Check interference.
+ return liveness.MayInterfere(*points_to_a[0], *points_to_b[0]);
+ }
+
+  // Returns true if the top-level buffer for the given instruction may be
+  // live out of the entry computation.
+ // Precondition: instruction is array-shaped.
+ bool InstructionMaybeLiveOut(const BufferLiveness& liveness,
+ HloInstruction* instruction) {
+ return liveness.MaybeLiveOut(
+ GetBuffer(liveness, instruction, /*index=*/{}));
+ }
+
+ const Shape vec_ = ShapeUtil::MakeShape(xla::F32, {42});
+};
+
+TEST_F(BufferLivenessTest, ElementwiseChain) {
+ // A simple chain of elementwise operations. No buffers should interfere.
+ //
+ // param --> negate -> exp -> log
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kLog, exp));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // No buffers should interfere.
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, negate, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, exp, log));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, log));
+
+  // A buffer should interfere with itself.
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, exp, exp));
+
+ // Only log is live out.
+ EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
+ EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, negate));
+ EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, exp));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, log));
+}
+
+TEST_F(BufferLivenessTest, NonElementwiseOperand) {
+ // A chain of operations with one elementwise and one non-elementwise. The
+ // elementwise op should not interfere with its operand, while the
+ // non-elementwise op should interfere.
+ //
+ // param --> negate -> reverse
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+ auto reverse =
+ builder.AddInstruction(HloInstruction::CreateReverse(vec_, negate, {0}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+  // The elementwise negate does not interfere with its operand, but the
+  // non-elementwise reverse does interfere with its operand.
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, reverse));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
+}
+
+TEST_F(BufferLivenessTest, OverlappedBuffers) {
+ // Verify simultaneously live buffers interfere (exp and negate).
+ //
+ // param --> negate -> add
+ // \---> exp -----/
+ //
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, exp));
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
+}
+
+TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
+  // Identical to the test OverlappedBuffers but using a sequential ordering of
+ // HLO instructions.
+ //
+ // param --> negate -> add
+ // \---> exp -----/
+ //
+ // Sequential order:
+ // param, negate, exp, add
+ //
+  // Liveness is identical to the DependencyHloOrdering case except that
+  // 'param' and 'exp' no longer interfere.
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kExp, param));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, negate, exp));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ SequentialHloOrdering::HloModuleSequence module_sequence;
+ std::vector<const HloInstruction*> order = {param, negate, exp, add};
+ module_sequence.emplace(computation, order);
+ auto liveness =
+ BufferLiveness::Run(module.get(), MakeUnique<SequentialHloOrdering>(
+ module.get(), module_sequence))
+ .ConsumeValueOrDie();
+
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, negate, exp));
+ EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add));
+}
+
+TEST_F(BufferLivenessTest, TupleLiveOut) {
+ // Verify MaybeLiveOut with nested tuples. Result of computation looks like:
+ //
+  //   Tuple({Tuple({Negate(Param)}), Exp(Negate(Param))})
+ //
+ // All values should be live out except Param.
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kNegate, param));
+ auto inner_tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({negate}));
+ auto exp = builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kExp, negate));
+ auto outer_tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple, exp}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // All buffers should be live out except the param
+ EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, param));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, negate));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, inner_tuple));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, exp));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, outer_tuple));
+}
+
+// TODO: Add a test case for bitcast live-out.
+
+TEST_F(BufferLivenessTest, EmbeddedComputation) {
+ // Test MaybeLiveOut and MayInterfere for embedded computation.
+ auto module = MakeUnique<HloModule>(TestName());
+
+ auto embedded_builder = HloComputation::Builder(TestName() + "_embedded");
+ auto embedded_param = embedded_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, vec_, "embedded_param"));
+ auto embedded_log = embedded_builder.AddInstruction(
+ HloInstruction::CreateUnary(vec_, HloOpcode::kLog, embedded_param));
+
+ auto embedded_computation =
+ module->AddEmbeddedComputation(embedded_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param"));
+ auto call = builder.AddInstruction(
+ HloInstruction::CreateCall(vec_, {param}, embedded_computation));
+
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // Buffers in different computations should always interfere.
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_log, call));
+ EXPECT_TRUE(InstructionsMayInterfere(*liveness, embedded_param, param));
+ EXPECT_FALSE(
+ InstructionsMayInterfere(*liveness, embedded_param, embedded_log));
+
+ // The only buffers for which MaybeLiveOut == true are those live out
+ // of the entry computation. Buffers live out of embedded computations should
+ // return false for this method.
+ EXPECT_FALSE(InstructionMaybeLiveOut(*liveness, embedded_log));
+ EXPECT_TRUE(InstructionMaybeLiveOut(*liveness, call));
+}
+
+TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
+ // Verify non top-level elements of a nested tuple constant are properly
+ // marked as liveout. Computation:
+ //
+ //  GetTupleElement(0, TupleConstant({{0, 1}, {3}}))
+ //
+ // The buffers holding the inner tuple {0, 1} and its elements 0 and 1 are
+ // live out of the computation, since GetTupleElement forwards them as the
+ // root value. The top-level tuple buffer and the buffers holding {3} and 3
+ // are dead.
+ auto builder = HloComputation::Builder(TestName());
+ auto inner_tuple0 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(0).get(),
+ LiteralUtil::CreateR0<int64>(1).get()});
+ auto inner_tuple1 =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<int64>(3).get()});
+ auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()})));
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ inner_tuple0->shape(), tuple_constant, 0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // Only the element buffers of the tuple constant which are pointed to by
+ // the GetTupleElement instruction should be liveout.
+ EXPECT_FALSE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{})));
+ EXPECT_TRUE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{0})));
+ EXPECT_TRUE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{0, 0})));
+ EXPECT_TRUE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{0, 1})));
+ EXPECT_FALSE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{1})));
+ EXPECT_FALSE(liveness->MaybeLiveOut(
+ GetBuffer(*liveness, tuple_constant, /*index=*/{1, 0})));
+}
+
+TEST_F(BufferLivenessTest, IndependentTupleElements) {
+ auto builder = HloComputation::Builder(TestName());
+ // Create param0 Tuple.
+ auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {4})}),
+ "param0"));
+ // Create independent computations for each tuple element.
+
+ // Tuple element0 computation:
+ // Add(GetTupleElement(tuple_param0, 0), const0)
+ auto tuple_element0_shape =
+ ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
+ auto tuple_element0 =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ tuple_element0_shape, tuple_param0, 0));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
+
+ // Tuple element1 computation:
+ // Add(GetTupleElement(tuple_param0, 1), const1)
+ auto tuple_element1_shape =
+ ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
+ auto tuple_element1 =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ tuple_element1_shape, tuple_param0, 1));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f})));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1));
+
+ // Create output tuple.
+ auto tuple_root =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // We compare tuple element pairs that are input/output to the computation:
+ // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
+ // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
+
+ // Tuple output element 'add0' does not depend on input 'tuple_element1'.
+ // Tuple output element 'add1' does not depend on input 'tuple_element0'.
+
+ // Neither element pair interferes, because there is no other dependency on
+ // each pair's input tuple element, and so liveness can compute that all
+ // users of the input tuple element execute before the associated output
+ // tuple element.
+ EXPECT_FALSE(
+ TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
+ EXPECT_FALSE(
+ TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
+}
+
+TEST_F(BufferLivenessTest, DependentTupleElements) {
+ auto builder = HloComputation::Builder(TestName());
+ // Create param0 Tuple.
+ auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(F32, {8})}),
+ "param0"));
+ // Create dependent computations for each tuple element.
+
+ // Tuple element0 computation:
+ // Add(GetTupleElement(tuple_param0, 0), const0)
+ auto tuple_element0_shape =
+ ShapeUtil::GetSubshape(tuple_param0->shape(), {0});
+ auto tuple_element0 =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ tuple_element0_shape, tuple_param0, 0));
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0));
+
+ // Tuple element1 computation:
+ // Add(GetTupleElement(tuple_param0, 0), GetTupleElement(tuple_param0, 1))
+ auto tuple_element1_shape =
+ ShapeUtil::GetSubshape(tuple_param0->shape(), {1});
+ auto tuple_element1 =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ tuple_element1_shape, tuple_param0, 1));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ tuple_element1_shape, HloOpcode::kAdd, tuple_element0, tuple_element1));
+
+ // Create output tuple.
+ auto tuple_root =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ auto liveness =
+ BufferLiveness::Run(module.get(),
+ MakeUnique<DependencyHloOrdering>(module.get()))
+ .ConsumeValueOrDie();
+
+ // We compare tuple element pairs that are input/output to the computation:
+ // 1) (input_tuple_element, output_tuple_element) = ('tuple_element0', 'add0')
+ // 2) (input_tuple_element, output_tuple_element) = ('tuple_element1', 'add1')
+
+ // The first tuple element pair's output 'add0' has no dependency on the
+ // second pair's input 'tuple_element1'.
+
+ // The second tuple element pair's output 'add1' has a dependency on the
+ // first pair's input 'tuple_element0'.
+
+ // The first tuple element pair does interfere, because liveness cannot
+ // compute that all references to 'tuple_element0' are executed before 'add0'
+ // (because of the dependency of 'add1' on 'tuple_element0').
+ EXPECT_TRUE(
+ TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {0}));
+
+ // The second tuple element pair does not interfere, because there is no
+ // other dependency on 'tuple_element1', and so liveness can compute that
+ // all users execute before 'add1'.
+ EXPECT_FALSE(
+ TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1}));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc
new file mode 100644
index 0000000000..b3784c36ff
--- /dev/null
+++ b/tensorflow/compiler/xla/service/channel_tracker.cc
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/channel_tracker.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+ChannelTracker::ChannelTracker() : next_channel_(1) {}
+
+ChannelHandle ChannelTracker::NewChannel() {
+ tensorflow::mutex_lock lock(channel_mutex_);
+
+ // Create a new channel handle with a unique value.
+ const ChannelHandle new_handle = AllocateHandle();
+
+ // Register a channel object associated with the handle.
+ Channel channel;
+ channel.has_sender = false;
+ channel.receiver_count = 0;
+ opaque_to_channel_[new_handle.handle()] = channel;
+
+ return new_handle;
+}
+
+Status ChannelTracker::RegisterSend(const ChannelHandle& handle) {
+ tensorflow::mutex_lock lock(channel_mutex_);
+ return RegisterSendInternal(handle);
+}
+
+Status ChannelTracker::RegisterRecv(const ChannelHandle& handle) {
+ tensorflow::mutex_lock lock(channel_mutex_);
+ return RegisterRecvInternal(handle);
+}
+
+ChannelHandle ChannelTracker::AllocateHandle() {
+ int64 handle_value = next_channel_++;
+ ChannelHandle result;
+ result.set_handle(handle_value);
+ return result;
+}
+
+Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) {
+ if (opaque_to_channel_.count(handle.handle()) == 0) {
+ return NotFound("channel handle not found: %lld", handle.handle());
+ }
+ Channel& channel = opaque_to_channel_[handle.handle()];
+ if (channel.has_sender) {
+ return FailedPrecondition("channel handle is already used by a sender");
+ }
+ channel.has_sender = true;
+ return Status::OK();
+}
+
+Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) {
+ if (opaque_to_channel_.count(handle.handle()) == 0) {
+ return NotFound("channel handle not found: %lld", handle.handle());
+ }
+ Channel& channel = opaque_to_channel_[handle.handle()];
+ // TODO(b/33942691): Allow more than one receiver for broadcast.
+ if (channel.receiver_count >= 1) {
+ return FailedPrecondition("channel handle is already used by a receiver");
+ }
+ channel.receiver_count += 1;
+ return Status::OK();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/channel_tracker.h b/tensorflow/compiler/xla/service/channel_tracker.h
new file mode 100644
index 0000000000..c7763f2ca3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/channel_tracker.h
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_
+
+#include <map>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Tracks channels between computations in the XLA service. Channels
+// are associated with a unique handle and can be resolved from the handle for
+// later use.
+//
+// TODO(b/34027823): Destruct channels when all the associated computations that
+// communicate via each channel are destructed.
+class ChannelTracker {
+ public:
+ ChannelTracker();
+
+ // A struct that keeps the current status of each channel. The has_sender and
+ // receiver_count fields are initialized to false and 0 respectively when the
+ // struct is created, and are updated by RegisterSend() and RegisterRecv() as
+ // Send or Recv instructions using the channel are requested.
+ struct Channel {
+ bool has_sender;
+ int64 receiver_count;
+ };
+
+ // Creates a new Channel object and returns the corresponding
+ // ChannelHandle for it.
+ ChannelHandle NewChannel();
+
+ // Informs the tracker that the given channel handle is used by a Send
+ // operation. Returns an error status if the handle is already used by
+ // another Send.
+ Status RegisterSend(const ChannelHandle& handle);
+
+ // Informs the tracker that the given channel handle is used by a Recv
+ // operation. Returns an error status if the handle is already used by
+ // another Recv.
+ Status RegisterRecv(const ChannelHandle& handle);
+
+ private:
+ // Bumps the next_channel_ number and returns the allocated number
+ // wrapped in a ChannelHandle.
+ ChannelHandle AllocateHandle() EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_);
+
+ Status RegisterSendInternal(const ChannelHandle& handle)
+ EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_);
+
+ Status RegisterRecvInternal(const ChannelHandle& handle)
+ EXCLUSIVE_LOCKS_REQUIRED(channel_mutex_);
+
+ // Guards the channel mapping.
+ tensorflow::mutex channel_mutex_;
+
+ // The next sequence number to assign to a channel.
+ int64 next_channel_ GUARDED_BY(channel_mutex_);
+
+ // Mapping from ChannelHandle value to the corresponding registered
+ // Channel object.
+ std::map<int64, Channel> opaque_to_channel_ GUARDED_BY(channel_mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ChannelTracker);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHANNEL_TRACKER_H_
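A minimal usage sketch of the tracker above, based only on the API declared in this header and implemented in channel_tracker.cc; the helper function and its calling context are illustrative assumptions rather than code from the XLA tree.

    #include "tensorflow/compiler/xla/service/channel_tracker.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    // Hypothetical helper: allocates a channel, registers one sender and one
    // receiver, and checks that a duplicate sender is rejected (see
    // RegisterSendInternal above, which returns FailedPrecondition).
    void DemoChannelRegistration(ChannelTracker* tracker) {
      const ChannelHandle handle = tracker->NewChannel();
      CHECK(tracker->RegisterSend(handle).ok());   // first sender succeeds
      CHECK(tracker->RegisterRecv(handle).ok());   // first receiver succeeds
      CHECK(!tracker->RegisterSend(handle).ok());  // second sender is refused
    }

    }  // namespace xla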
diff --git a/tensorflow/compiler/xla/service/compilation_cache.cc b/tensorflow/compiler/xla/service/compilation_cache.cc
new file mode 100644
index 0000000000..b16907da9e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compilation_cache.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compilation_cache.h"
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+std::shared_ptr<Executable> CompilationCache::Insert(
+ std::unique_ptr<Executable> executable,
+ const HloModuleConfig& module_config) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ CacheKey key =
+ BuildKey(executable->entry_computation_handle(), module_config);
+ VLOG(2) << "inserting cache key: " << key;
+ if (cache_.count(key) == 0) {
+ cache_.emplace(key, std::move(executable));
+ } else {
+ // Executable already exists in the cache. This can happen if two Execute
+ // calls for a new computation are received simultaneously by the
+ // service. In this case, we discard the Executable given as a parameter and
+ // return what is in the cache. This is necessary because the service relies
+ // on the cache to keep ownership of the Executable. We only want to store
+ // one Executable for a given computation version and we can't discard the
+ // executable which is in the cache because it may be in use.
+ executable.reset();
+ }
+ return cache_.at(key);
+}
+
+std::shared_ptr<Executable> CompilationCache::LookUp(
+ const VersionedComputationHandle& versioned_handle,
+ const HloModuleConfig& module_config) const {
+ tensorflow::mutex_lock lock(mutex_);
+
+ CacheKey key = BuildKey(versioned_handle, module_config);
+ VLOG(2) << "looking up cache key: " << key;
+ if (cache_.count(key) == 0) {
+ VLOG(2) << "cache key not found: " << key;
+ return nullptr;
+ } else {
+ std::shared_ptr<Executable> result = cache_.at(key);
+ VLOG(2) << "hit executable with module config: "
+ << result->module_config().compilation_cache_key();
+ return result;
+ }
+}
+
+CompilationCache::CacheKey CompilationCache::BuildKey(
+ const VersionedComputationHandle& versioned_handle,
+ const HloModuleConfig& module_config) const {
+ // The computation shape is represented entirely by its ProgramShape member,
+ // so just serialize the proto as part of the key.
+ return tensorflow::strings::StrCat(versioned_handle.handle.handle(), "::",
+ versioned_handle.version, "::",
+ module_config.compilation_cache_key());
+}
+
+} // namespace xla
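As a concrete illustration of BuildKey above (the values are made up): a versioned handle with handle value 42 and version 3, paired with a module config whose compilation_cache_key() is "cfg", yields the cache key string shown below.

    // StrCat(42, "::", 3, "::", "cfg")
    //   => "42::3::cfg"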
diff --git a/tensorflow/compiler/xla/service/compilation_cache.h b/tensorflow/compiler/xla/service/compilation_cache.h
new file mode 100644
index 0000000000..09989726ae
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compilation_cache.h
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+
+// A cache which stores Executables indexed by computation handle and version.
+class CompilationCache {
+ public:
+ CompilationCache() {}
+
+ // Inserts the given Executable into the cache. Returns a shared_ptr to the
+ // Executable for the caller to use. Note: the returned Executable will *not*
+ // be the one passed in if an Executable for the computation already exists in
+ // the cache. See comments in the .cc implementation for details of this case.
+ //
+ // module_config is provided by the caller, instead of being taken from the
+ // executable, so that we can insert keys into the compilation cache that are
+ // devoid of layout (where XLA gets to choose what layout to compile).
+ //
+ // A shared_ptr is returned so the caller can keep the Executable from being
+ // destructed in the event that the Executable is evicted from the
+ // computation cache (and the cache's shared_ptr to the Executable is
+ // destructed).
+ std::shared_ptr<Executable> Insert(std::unique_ptr<Executable> executable,
+ const HloModuleConfig& module_config);
+
+ // Looks up the Executable for the specified versioned computation in the
+ // cache. Returns a shared_ptr to the Executable if it exists in the cache;
+ // returns nullptr otherwise.
+ std::shared_ptr<Executable> LookUp(
+ const VersionedComputationHandle& versioned_handle,
+ const HloModuleConfig& module_config) const;
+
+ protected:
+ mutable tensorflow::mutex mutex_;
+
+ // Map from versioned handle with program layout to Executable built
+ // for that computation version and program layout.
+ using CacheKey = string;
+
+ CacheKey BuildKey(const VersionedComputationHandle& versioned_handle,
+ const HloModuleConfig& module_config) const;
+ std::map<CacheKey, std::shared_ptr<Executable>> cache_ GUARDED_BY(mutex_);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CompilationCache);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILATION_CACHE_H_
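A sketch of the lookup-then-insert flow the comments above describe, assuming the caller has already compiled an Executable and holds its HloModuleConfig; the function name and variables are illustrative, not part of the XLA sources.

    #include <memory>
    #include <utility>

    #include "tensorflow/compiler/xla/service/compilation_cache.h"

    namespace xla {

    // Hypothetical helper: returns a cached Executable for this computation
    // version if one exists, otherwise hands ownership of 'executable' to the
    // cache. Insert() may return a different Executable than the one passed in
    // if another caller inserted concurrently (see the .cc comments above).
    std::shared_ptr<Executable> GetOrInsert(
        CompilationCache* cache,
        const VersionedComputationHandle& versioned_handle,
        const HloModuleConfig& config,
        std::unique_ptr<Executable> executable) {
      std::shared_ptr<Executable> cached = cache->LookUp(versioned_handle, config);
      if (cached != nullptr) {
        return cached;
      }
      return cache->Insert(std::move(executable), config);
    }

    }  // namespace xla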
diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc
new file mode 100644
index 0000000000..f71b2b6b9c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compiler.cc
@@ -0,0 +1,96 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+
+#include <mutex>  // NOLINT: for std::once_flag / std::call_once used below
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/* static */ tensorflow::mutex* Compiler::platform_compiler_mutex_;
+
+/* static */ void Compiler::LazyInitMutex() {
+ static std::once_flag mutex_init_flag;
+ std::call_once(mutex_init_flag, []() {
+ Compiler::platform_compiler_mutex_ = new tensorflow::mutex;
+ });
+}
+
+/* static */ std::map<perftools::gputools::Platform::Id,
+ Compiler::CompilerFactory>*
+Compiler::GetPlatformCompilerFactories() {
+ static auto* r =
+ new std::map<perftools::gputools::Platform::Id, CompilerFactory>;
+ return r;
+}
+
+/* static */
+std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>*
+Compiler::GetPlatformCompilers() {
+ static auto* r = new std::map<perftools::gputools::Platform::Id,
+ std::unique_ptr<Compiler>>;
+ return r;
+}
+
+/* static */ void Compiler::RegisterCompilerFactory(
+ se::Platform::Id platform_id,
+ std::function<std::unique_ptr<Compiler>()> compiler_factory) {
+ LazyInitMutex();
+ tensorflow::mutex_lock lock(*platform_compiler_mutex_);
+ auto* factories = GetPlatformCompilerFactories();
+ CHECK(factories->find(platform_id) == factories->end());
+ (*factories)[platform_id] = std::move(compiler_factory);
+}
+
+/* static */ StatusOr<Compiler*> Compiler::GetForPlatform(
+ const se::Platform* platform) {
+ LazyInitMutex();
+ tensorflow::mutex_lock lock(*platform_compiler_mutex_);
+
+ auto* compilers = GetPlatformCompilers();
+ // See if we already instantiated a compiler for this platform.
+ {
+ auto it = compilers->find(platform->id());
+ if (it != compilers->end()) {
+ return it->second.get();
+ }
+
+ // If not, we just fall through to try to create one with a registered
+ // factory.
+ }
+
+ auto* factories = GetPlatformCompilerFactories();
+ auto it = factories->find(platform->id());
+ if (it == factories->end()) {
+ return NotFound(
+ "could not find registered compiler for platform %s -- check "
+ "target linkage",
+ platform->Name().c_str());
+ }
+
+ // And then we invoke the factory, placing the result into the mapping.
+ compilers->insert(std::make_pair(platform->id(), it->second()));
+ return compilers->at(platform->id()).get();
+}
+
+} // namespace xla
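The registration half of this machinery is driven from each backend library at static-initialization time. The sketch below shows the expected shape of such a registration for a hypothetical backend; FakeCompiler, CreateFakeCompiler, and kFakePlatformId are placeholders for a real Compiler subclass, its factory, and its platform id, and the plain static initializer stands in for whatever module-initializer mechanism the backend library actually uses.

    #include <memory>

    #include "tensorflow/compiler/xla/service/compiler.h"

    namespace fake_backend {

    // Assumed to be defined elsewhere in the (hypothetical) backend library.
    extern perftools::gputools::Platform::Id kFakePlatformId;
    std::unique_ptr<xla::Compiler> CreateFakeCompiler();

    // Registers the factory once at program start-up; GetForPlatform() will
    // later invoke it lazily the first time the platform is requested.
    static bool RegisterFakeCompiler() {
      xla::Compiler::RegisterCompilerFactory(kFakePlatformId,
                                             &CreateFakeCompiler);
      return true;
    }
    static bool fake_compiler_registered = RegisterFakeCompiler();

    }  // namespace fake_backend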
diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h
new file mode 100644
index 0000000000..632081a747
--- /dev/null
+++ b/tensorflow/compiler/xla/service/compiler.h
@@ -0,0 +1,172 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// The compiler API is used by the XLA service to generate executables that
+// run on a given platform. This is a registry and abstract interface, for
+// pluggability by the various platforms.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+
+// The following types are used for ahead of time compilation.
+
+// Contains the object file data created as a result of ahead-of-time
+// compilation.
+using ObjectFileData = std::vector<char>;
+
+// Contains the buffer sizes information needed to allocate buffers to execute
+// an ahead-of-time computation. Entries which contain -1 designate a parameter
+// which should be skipped over during allocation.
+using BufferSizes = std::vector<int64>;
+
+// Abstract superclass describing the result of an ahead-of-time compilation.
+class AotCompilationResult {
+ public:
+ AotCompilationResult(const AotCompilationResult&) = delete;
+ AotCompilationResult& operator=(AotCompilationResult const&) = delete;
+
+ virtual ~AotCompilationResult() = default;
+
+ protected:
+ AotCompilationResult() = default;
+};
+
+// Abstract superclass describing options to an ahead-of-time compilation.
+class AotCompilationOptions {
+ public:
+ AotCompilationOptions(const AotCompilationOptions&) = delete;
+ AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;
+
+ virtual ~AotCompilationOptions() = default;
+
+ // Returns the ID of the platform to which these options apply.
+ virtual perftools::gputools::Platform::Id PlatformId() const = 0;
+
+ protected:
+ AotCompilationOptions() = default;
+};
+
+// Abstract compiler interface that is subclassed for compilation on a
+// particular platform.
+//
+// The compiler ties together high level optimization (HLO) and low level
+// optimization (LLO) / codegen (CG) to generate efficient executables for the
+// target platform.
+//
+// The platform-based compiler singletons are registered via module
+// initializers in their corresponding XLA compiler libraries, using the
+// RegisterCompilerFactory API below.
+//
+// Thread-safety: subclasses of Compiler must be thread-safe, as multiple
+// XLA clients may be requesting compilation concurrently for a given
+// platform.
+class Compiler {
+ public:
+ // Callback signature used to dump the HLO graph during compilation.
+ // Different compiler backends will call this as they please, providing
+ // a view of the HLO at different points in compilation -- context for the
+ // dump is indicated by the label string.
+ using HloDumper =
+ std::function<void(const HloModule& module, const string& label)>;
+
+ virtual ~Compiler() {}
+
+ // Returns the ID of the platform that this compiler targets.
+ virtual perftools::gputools::Platform::Id PlatformId() const = 0;
+
+ // Compiles the HLO module for execution on a device given by the executor,
+ // and returns an executable object or an error status. Takes ownership of the
+ // HLO module and is free to transform it.
+ //
+ // The compiler may optionally specialize to the individual device
+ // (not just type of device) indicated by the executor.
+ //
+ // TODO(leary) will need to update this API when a single computation can run
+ // across multiple devices simultaneously.
+ virtual StatusOr<std::unique_ptr<Executable>> Compile(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ perftools::gputools::StreamExecutor* executor) = 0;
+
+ // Compiles a set of HLO modules that can run in parallel, potentially
+ // communicating data between the modules, and returns a corresponding
+ // sequence of executable objects.
+ virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_module,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_config,
+ HloDumper dump_hlo,
+ std::vector<perftools::gputools::StreamExecutor*> stream_exec) = 0;
+
+ // Compiles the HLO module for ahead-of-time execution. This is intended for
+ // use in static compilation.
+ virtual StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ const AotCompilationOptions& options) = 0;
+
+ /////
+ // The Compiler class also serves as a point to register compiler objects
+ // for the various platforms.
+
+ using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;
+
+ // Registers the compiler singleton for the platform. This is assumed to
+ // be a singleton, so no ownership is transferred.
+ //
+ // Precondition: a platform kind must not be registered more than once.
+ static void RegisterCompilerFactory(
+ perftools::gputools::Platform::Id platform_id,
+ CompilerFactory compiler_factory);
+
+ // Returns the compiler singleton pointer if it is available for the given
+ // platform, or an error status if it is not.
+ static StatusOr<Compiler*> GetForPlatform(
+ const perftools::gputools::Platform* platform);
+
+ private:
+ // Mutex that guards the platform-compiler map.
+ static tensorflow::mutex* platform_compiler_mutex_;
+ static void LazyInitMutex();
+
+ // Map from platform kind to compiler factory.
+ static std::map<perftools::gputools::Platform::Id, CompilerFactory>*
+ GetPlatformCompilerFactories();
+
+ // Map from platform kind to compiler instance, if we made one already (based
+ // on the factories above).
+ static std::map<perftools::gputools::Platform::Id, std::unique_ptr<Compiler>>*
+ GetPlatformCompilers();
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPILER_H_
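A sketch of how the service side presumably drives this interface for a single module: look up the registered compiler for the platform, then hand it the module, its config, and a dumper callback. 'platform', 'module', 'module_config', and 'executor' are assumed to be supplied by the caller; the function itself is illustrative only.

    #include <memory>
    #include <utility>

    #include "tensorflow/compiler/xla/service/compiler.h"
    #include "tensorflow/compiler/xla/status_macros.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    // Illustrative only; mirrors the single-module Compile() declared above.
    StatusOr<std::unique_ptr<Executable>> CompileOneModule(
        const perftools::gputools::Platform* platform,
        std::unique_ptr<HloModule> module,
        std::unique_ptr<HloModuleConfig> module_config,
        perftools::gputools::StreamExecutor* executor) {
      TF_ASSIGN_OR_RETURN(Compiler * compiler,
                          Compiler::GetForPlatform(platform));
      // A trivial dumper that just logs the label supplied by the backend.
      Compiler::HloDumper dump_hlo = [](const HloModule& hlo_module,
                                        const string& label) {
        VLOG(2) << "HLO dump requested: " << label;
      };
      return compiler->Compile(std::move(module), std::move(module_config),
                               dump_hlo, executor);
    }

    }  // namespace xla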
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
new file mode 100644
index 0000000000..d2d4f14fce
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+
+#include <algorithm>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+ComputationLayout::ComputationLayout(const ProgramShape& program_shape)
+ : result_layout_(program_shape.result()) {
+ for (auto& shape : program_shape.parameters()) {
+ parameter_layouts_.emplace_back(shape);
+ }
+ SetToDefaultLayout();
+}
+
+void ComputationLayout::SetToDefaultLayout() {
+ for (auto& parameter_layout : parameter_layouts_) {
+ parameter_layout.SetToDefaultLayout();
+ }
+ result_layout_.SetToDefaultLayout();
+}
+
+bool ComputationLayout::LayoutIsSet() const {
+ return std::all_of(parameter_layouts_.begin(), parameter_layouts_.end(),
+ [](const ShapeLayout& s) { return s.LayoutIsSet(); }) &&
+ result_layout_.LayoutIsSet();
+}
+
+string ComputationLayout::ToString() const {
+ std::vector<string> params;
+ for (auto& param_layout : parameter_layouts_) {
+ params.push_back(param_layout.ToString());
+ }
+ return tensorflow::strings::StrCat("(",
+ tensorflow::str_util::Join(params, ", "),
+ ") => ", result_layout_.ToString());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h
new file mode 100644
index 0000000000..80e102411c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_layout.h
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class which contains the layouts of the parameters and results of a
+// computation. The layouts are stored as ShapeLayouts with immutable shapes and
+// mutable layouts.
+class ComputationLayout {
+ public:
+ // Constructs a ComputationLayout from a ProgramShape. The layouts of the
+ // parameters and results are set to the default layout. Layouts in the
+ // ProgramShape are ignored.
+ explicit ComputationLayout(const ProgramShape& program_shape);
+
+ // Returns the layout of a particular parameter.
+ const ShapeLayout& parameter_layout(int64 param_no) const {
+ return parameter_layouts_[param_no];
+ }
+ ShapeLayout* mutable_parameter_layout(int64 param_no) {
+ return &parameter_layouts_[param_no];
+ }
+
+ // Returns the number of parameters in the computation.
+ int parameter_count() const { return parameter_layouts_.size(); }
+
+ // Returns the ShapeLayouts of the parameters of the computation.
+ const std::vector<ShapeLayout>& parameter_layouts() const {
+ return parameter_layouts_;
+ }
+
+ // Returns the ShapeLayout of a result of the computation.
+ const ShapeLayout& result_layout() const { return result_layout_; }
+ ShapeLayout* mutable_result_layout() { return &result_layout_; }
+
+ // Returns the shape of the particular parameter or result of the computation
+ // with layout.
+ const Shape& parameter_shape(int64 param_no) const {
+ return parameter_layouts_[param_no].shape();
+ }
+ const Shape& result_shape() const { return result_layout_.shape(); }
+
+ // Sets layouts of all parameters and the result to the default layout.
+ void SetToDefaultLayout();
+
+ // Returns true if all layouts (parameters and result) have been set.
+ bool LayoutIsSet() const;
+
+ // Returns a string representation of this object.
+ string ToString() const;
+
+ private:
+ std::vector<ShapeLayout> parameter_layouts_;
+ ShapeLayout result_layout_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_LAYOUT_H_
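A minimal sketch of the class above in use, assuming a made-up (f32[8], s32[4]) -> f32[8] program shape; the demo function is illustrative and not part of the XLA sources.

    #include "tensorflow/compiler/xla/service/computation_layout.h"
    #include "tensorflow/compiler/xla/shape_util.h"
    #include "tensorflow/core/platform/logging.h"

    namespace xla {

    // Illustration only: the constructor applies default layouts, so
    // LayoutIsSet() already holds before any mutable_*_layout() calls.
    void DemoComputationLayout() {
      ProgramShape program_shape;
      *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {8});
      *program_shape.add_parameters() = ShapeUtil::MakeShape(S32, {4});
      *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {8});

      ComputationLayout layout(program_shape);
      CHECK_EQ(2, layout.parameter_count());
      CHECK(layout.LayoutIsSet());
      VLOG(1) << layout.ToString();  // "(param layouts) => result layout"
    }

    }  // namespace xla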
diff --git a/tensorflow/compiler/xla/service/computation_tracker.cc b/tensorflow/compiler/xla/service/computation_tracker.cc
new file mode 100644
index 0000000000..281277bed5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_tracker.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+
+#include <list>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+ComputationTracker::ComputationTracker() : next_computation_(1) {}
+
+ComputationHandle ComputationTracker::NewComputation(
+ const string& computation_name) {
+ tensorflow::mutex_lock lock(computation_mutex_);
+ ComputationHandle computation_handle;
+ int64 handle_value = next_computation_++;
+ computation_handle.set_handle(handle_value);
+ opaque_to_computation_[handle_value] =
+ MakeUnique<UserComputation>(computation_name, computation_handle);
+ return computation_handle;
+}
+
+StatusOr<ComputationHandle> ComputationTracker::LoadSessionModule(
+ const SessionModule& session_module) {
+ tensorflow::mutex_lock lock(computation_mutex_);
+
+ // For each embedded computation, create a new computation based on its
+ // serialized data, and record the mapping from the old computation handle to
+ // the new computation handle.
+ std::map<int64, ComputationHandle> old_to_new;
+ for (const SessionComputation& computation :
+ session_module.embedded_computations()) {
+ const int64 old_handle = computation.computation_handle().handle();
+ TF_ASSIGN_OR_RETURN(old_to_new[old_handle],
+ LoadSessionComputation(computation, &old_to_new));
+ }
+
+ // Finally, place the entry computation in the tracker with all of the
+ // remappings populated from the above.
+ const int64 old_handle = session_module.entry().computation_handle().handle();
+ TF_ASSIGN_OR_RETURN(
+ old_to_new[old_handle],
+ LoadSessionComputation(session_module.entry(), &old_to_new));
+ return old_to_new[old_handle];
+}
+
+StatusOr<std::unique_ptr<SessionModule>>
+ComputationTracker::SnapshotComputation(const ComputationHandle& computation) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation, Resolve(computation));
+ const VersionedComputationHandle entry_versioned_handle =
+ user_computation->GetVersionedHandle();
+ std::set<VersionedComputationHandle> visited;
+ std::list<VersionedComputationHandle> post_order;
+ {
+ tensorflow::mutex_lock lock(computation_mutex_);
+ ComputeComputationPostOrder(entry_versioned_handle, &visited, &post_order);
+ }
+ auto session_module = MakeUnique<SessionModule>();
+ *session_module->mutable_entry() =
+ Resolve(entry_versioned_handle.handle)
+ .ValueOrDie()
+ ->CloneSessionComputation(entry_versioned_handle.version);
+ for (auto it = ++post_order.rbegin(); it != post_order.rend(); ++it) {
+ *session_module->add_embedded_computations() =
+ Resolve(it->handle).ValueOrDie()->CloneSessionComputation(it->version);
+ }
+ return std::move(session_module);
+}
+
+StatusOr<UserComputation*> ComputationTracker::Resolve(
+ const ComputationHandle& computation) const {
+ tensorflow::mutex_lock lock(computation_mutex_);
+ return ResolveInternal(computation);
+}
+
+ComputationHandle ComputationTracker::AllocateHandle() {
+ int64 handle_value = next_computation_++;
+ ComputationHandle result;
+ result.set_handle(handle_value);
+ return result;
+}
+
+StatusOr<ComputationHandle> ComputationTracker::LoadSessionComputation(
+ const SessionComputation& session_computation,
+ std::map<int64, ComputationHandle>* old_to_new) {
+ TF_RET_CHECK(old_to_new != nullptr);
+ const ComputationHandle new_handle = AllocateHandle();
+ (*old_to_new)[session_computation.computation_handle().handle()] = new_handle;
+ TF_ASSIGN_OR_RETURN(opaque_to_computation_[new_handle.handle()],
+ UserComputation::MakeWithRemapping(
+ session_computation, new_handle, *old_to_new));
+ return new_handle;
+}
+
+StatusOr<UserComputation*> ComputationTracker::ResolveInternal(
+ const ComputationHandle& computation) const {
+ auto it = opaque_to_computation_.find(computation.handle());
+ if (it == opaque_to_computation_.end()) {
+ return NotFound("computation handle not found: %lld", computation.handle());
+ }
+ UserComputation* user_computation = it->second.get();
+ return user_computation;
+}
+
+void ComputationTracker::ComputeComputationPostOrder(
+ const VersionedComputationHandle& versioned_handle,
+ std::set<VersionedComputationHandle>* visited,
+ std::list<VersionedComputationHandle>* post_order) const {
+ if (visited->count(versioned_handle) > 0) {
+ DCHECK_EQ(1, visited->count(versioned_handle));
+ return;
+ }
+
+ UserComputation* computation =
+ ResolveInternal(versioned_handle.handle).ValueOrDie();
+ std::vector<VersionedComputationHandle> embedded_handles =
+ computation->GetEmbeddedComputations(versioned_handle.version);
+
+ for (const auto& embedded_handle : embedded_handles) {
+ ComputeComputationPostOrder(embedded_handle, visited, post_order);
+ }
+
+ visited->insert(versioned_handle);
+ post_order->push_back(versioned_handle);
+ return;
+}
+
+StatusOr<std::unique_ptr<HloModule>> ComputationTracker::BuildHloModule(
+ const VersionedComputationHandle& entry_handle,
+ bool include_unused_parameters) const {
+ tensorflow::mutex_lock lock(computation_mutex_);
+
+ TF_ASSIGN_OR_RETURN(UserComputation * entry_computation,
+ ResolveInternal(entry_handle.handle));
+
+ // Build a topological sort of the entry and any embedded computations as a
+ // list. The root of the computation will be the last element in the list.
+ std::set<VersionedComputationHandle> visited;
+ std::list<VersionedComputationHandle> post_order;
+ ComputeComputationPostOrder(entry_handle, &visited, &post_order);
+
+ // Map from ComputationHandle value and computation version to HloComputation.
+ std::map<VersionedComputationHandle, HloComputation*> hlo_computations;
+
+ // The resolver lambda resolves VersionedHandles to embedded
+ // HloComputation*. This is required by UserComputation::BuildHloComputation
+ // when lowering calling operations (map, reduce etc).
+ auto resolver = [&hlo_computations](
+ const VersionedComputationHandle& versioned_handle) -> HloComputation* {
+ CHECK_GT(hlo_computations.count(versioned_handle), 0);
+ return hlo_computations.at(versioned_handle);
+ };
+
+ string module_name =
+ tensorflow::strings::StrCat(entry_computation->name(), "_module");
+ auto module = MakeUnique<HloModule>(module_name, entry_handle);
+ for (auto versioned_handle : post_order) {
+ UserComputation* computation =
+ ResolveInternal(versioned_handle.handle).ValueOrDie();
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloComputation> hlo_computation,
+ computation->BuildHloComputation(versioned_handle.version, resolver,
+ include_unused_parameters));
+
+ // Add the newly created computation to VersionedHandle-to-HloComputation
+ // map.
+ DCHECK_EQ(0, hlo_computations.count(versioned_handle));
+ hlo_computations[versioned_handle] = hlo_computation.get();
+
+ if (computation == entry_computation) {
+ module->AddEntryComputation(std::move(hlo_computation));
+ } else {
+ module->AddEmbeddedComputation(std::move(hlo_computation));
+ }
+ }
+
+ return std::move(module);
+}
+
+} // namespace xla
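As a concrete illustration of the post-order invariant relied on by SnapshotComputation and BuildHloModule above (the computation names are made up):

    // Call graph:  entry E calls M (e.g. via kMap), and M in turn calls F.
    // ComputeComputationPostOrder produces:  [F, M, E]
    // i.e. every callee precedes its callers, and the entry comes last.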
diff --git a/tensorflow/compiler/xla/service/computation_tracker.h b/tensorflow/compiler/xla/service/computation_tracker.h
new file mode 100644
index 0000000000..7d0660d7f6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/computation_tracker.h
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
+
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Tracks computations for the XLA service; computations can be registered
+// with a UserComputation instance and can be resolved from a handle for later
+// use.
+//
+// This class is also capable of serializing/deserializing computations that it
+// tracks (and to serialize properly you need to serialize all referred-to
+// computations as well).
+class ComputationTracker {
+ public:
+ ComputationTracker();
+
+ // Creates a new UserComputation object with the given name and returns the
+ // corresponding ComputationHandle for it. A fresh handle value is allocated,
+ // so the new entry cannot collide with a computation already in the map.
+ ComputationHandle NewComputation(const string& computation_name);
+
+ // Restores session data for a computation that has been serialized, and
+ // allocates a new computation handle for it.
+ StatusOr<ComputationHandle> LoadSessionModule(
+ const SessionModule& session_module);
+
+ // Snapshots a computation (referenced by the provided handle) at its latest
+ // version, returning a module where it is the entry, and any referred-to
+ // computations are entrained as "embedded" (non-entry) computations.
+ StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
+ const ComputationHandle& computation);
+
+ // Resolves a ComputationHandle to a UserComputation that is present in the
+ // map.
+ StatusOr<UserComputation*> Resolve(
+ const ComputationHandle& computation) const;
+
+ // Builds an HLO module using the specified computation as the entry. The
+ // module will include the entry computation as well as all computations which
+ // are called directly or indirectly from the entry computation via operations
+ // like "map". If include_unused_parameters is true, then all parameters are
+ // lowered to HLO instructions even if they are not used. This ensures the
+ // entry HloComputation has the same program shape (ProgramShape) as the entry
+ // UserComputation.
+ StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
+ const VersionedComputationHandle& entry_handle,
+ bool include_unused_parameters = true) const;
+
+ private:
+ // Bumps the next_computation_ number and returns the allocated number wrapped
+ // in a ComputationHandle.
+ ComputationHandle AllocateHandle()
+ EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+ // Loads a session computation into a UserComputation, registers it, and
+ // returns the computation handle of the registered computation. If old_to_new
+ // is provided, it is used for remapping references to computations present in
+ // session_computation.
+ //
+ // old_to_new will be updated with the mapping from session_computation's old
+ // handle to the returned handle value, and may not be null.
+ StatusOr<ComputationHandle> LoadSessionComputation(
+ const SessionComputation& session_computation,
+ std::map<int64, ComputationHandle>* old_to_new)
+ EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+ // Internal implementation of Resolve method which requires, but does not
+ // acquire the mutex.
+ StatusOr<UserComputation*> ResolveInternal(
+ const ComputationHandle& computation) const
+ EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+ // Builds a post order sort of a computation ("entry") and all of its embedded
+ // computations including all transitively embedded computations. An embedded
+ // computation (the callee) will always appear in the sort before the
+ // computation which calls the embedded computation (the caller). Necessarily,
+ // the entry computation is the last element in the sort. 'visited' and
+ // 'post_order' should be empty when calling; 'post_order' contains the post
+ // order sort when the function returns.
+ void ComputeComputationPostOrder(
+ const VersionedComputationHandle& versioned_handle,
+ std::set<VersionedComputationHandle>* visited,
+ std::list<VersionedComputationHandle>* post_order) const
+ EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
+
+ // Guards the computation mapping. Marked mutable so that the Resolve method
+ // can remain const; Resolve doesn't really modify the tracker in any way, but
+ // it has to lock the mutex for safety.
+ mutable tensorflow::mutex computation_mutex_;
+
+ // The next sequence number to assign to a computation, guarded by the same
+ // mutex as the mapping as they'll be mutated at the same time.
+ int64 next_computation_ GUARDED_BY(computation_mutex_);
+
+ // Mapping from ComputationHandle value to the corresponding registered
+ // UserComputation object.
+ std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
+ GUARDED_BY(computation_mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
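A minimal sketch of the tracker's lifecycle from a client's point of view, using only the API declared above. Populating the UserComputation with operations happens through the UserComputation interface and is elided; the helper and variable names are illustrative assumptions.

    #include <memory>

    #include "tensorflow/compiler/xla/service/computation_tracker.h"

    namespace xla {

    // Hypothetical helper: creates a computation, resolves it back from its
    // handle, and lowers its latest version to an HloModule.
    std::unique_ptr<HloModule> DemoBuildModule(ComputationTracker* tracker) {
      ComputationHandle handle = tracker->NewComputation("demo_computation");

      // ... operations would be added to the UserComputation here ...

      UserComputation* computation = tracker->Resolve(handle).ValueOrDie();
      VersionedComputationHandle versioned_handle =
          computation->GetVersionedHandle();
      return tracker->BuildHloModule(versioned_handle).ConsumeValueOrDie();
    }

    }  // namespace xla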
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
new file mode 100644
index 0000000000..dbf5085c1e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -0,0 +1,439 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+
+#include <memory>
+#include <set>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace {
+
+// InstructionCopier encapsulates indices at which to copy 'instruction'.
+// All 'instruction' users in 'copy_users' are updated to use the copy.
+//
+// Instruction copies are generated in two phases:
+// 1) Recording buffer indices at which 'instruction' requires copies (i.e.
+// setting 'indices_to_copy_[index]'=true).
+// 2) Inserting kCopy instructions based on indices recorded in phase 1).
+// *) Array instructions are copied by inserting a single kCopy instruction.
+// *) Tuple-shaped instructions are copied by recursively expanding tuples
+// (and tuple-shaped elements), and inserting kCopy instructions for any
+// tuple elements which require a copy. As the recursion unwinds, new tuple
+// instructions are added to gather the copied (and uncopied) references
+// into the output tuple (i.e. the copy of the tuple-shaped instruction).
+//
+// Example two-element tuple with one element that needs a copy:
+//
+// Tuple // instruction
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy |
+// \ /
+// Tuple // copied-instruction
+//
+class InstructionCopier {
+ public:
+ InstructionCopier(const bool init_value, HloInstruction* instruction,
+ const std::vector<HloInstruction*>& copy_users);
+
+ // Returns true if all recorded indices are false (and false otherwise).
+ bool HasAllIndicesFalse() const;
+
+ // Records instruction buffer indices which point-to a Parameter or Constant.
+ tensorflow::Status RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis);
+
+ // Records instruction buffer indices to copy which are necessary to ensure:
+ // *) PointsToSet of 'instruction_' is unambiguous and distinct.
+ // *) No liveness interference between 'instruction_' and 'other_instruction'.
+ tensorflow::Status RecordIndicesToCopyForColocatingBuffers(
+ BufferLiveness* liveness, HloInstruction* other_instruction);
+
+ // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy',
+ // and replaces all uses for instructions in 'copy_users_' with copy.
+ // Returns the instruction which is a copy of 'instruction'.
+ HloInstruction* Copy();
+
+ HloInstruction* instruction() { return instruction_; }
+
+ const std::vector<HloInstruction*>& copy_users() const { return copy_users_; }
+
+ private:
+ // Records instruction buffer indices which have ambiguous or non-distinct
+ // points-to sets.
+ tensorflow::Status RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis);
+
+ // Records instruction buffer indices which have interfering live ranges
+ // with 'other_instruction' buffers at same index.
+ tensorflow::Status RecordIndicesWhichInterfereWithOtherInstruction(
+ BufferLiveness* liveness, HloInstruction* other_instruction);
+
+ // Recursively inserts copies of 'instruction' tuple elements at indices
+ // specified in 'indices_to_copy', and returns the copy of 'instruction'.
+ HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index);
+
+ void RecordIndex(const ShapeIndex& index) {
+ *indices_to_copy_.mutable_element(index) = true;
+ }
+
+ HloInstruction* instruction_;
+ std::vector<HloInstruction*> copy_users_;
+ ShapeTree<bool> indices_to_copy_;
+};
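+
+// Illustrative usage sketch (not invoked by the pass itself): the two phases
+// described above, assuming a BufferLiveness* 'liveness' and a root
+// instruction 'root' are in scope.
+//
+//   InstructionCopier copier(/*init_value=*/false, root, /*copy_users=*/{});
+//   TF_RETURN_IF_ERROR(copier.RecordIndicesWhichPointToParamOrConstant(
+//       liveness->points_to_analysis()));           // Phase 1: record indices.
+//   if (!copier.HasAllIndicesFalse()) {
+//     HloInstruction* copy = copier.Copy();          // Phase 2: insert kCopy.
+//   }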
+
+InstructionCopier::InstructionCopier(
+ const bool init_value, HloInstruction* instruction,
+ const std::vector<HloInstruction*>& copy_users)
+ : instruction_(instruction),
+ copy_users_(copy_users),
+ indices_to_copy_(instruction->shape(), init_value) {}
+
+bool InstructionCopier::HasAllIndicesFalse() const {
+ bool all_indices_false = true;
+ TF_CHECK_OK(indices_to_copy_.ForEachElement([&all_indices_false](
+ const ShapeIndex& /*index*/, bool /*is_leaf*/, const bool& data) {
+ if (data) all_indices_false = false;
+ return tensorflow::Status::OK();
+ }));
+ return all_indices_false;
+}
+
+tensorflow::Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Shallow copy the instruction if the points-to set of the top-level
+ // buffer is ambiguous. This is necessary because the backends must know
+ // statically what the top-level buffer of the result is.
+ if (points_to.element(/*index=*/{}).size() > 1) {
+ RecordIndex({});
+ }
+
+ // Multiple buffers within a parameter/constant may be live out, so collect
+ // a set of indices at which to copy first.
+ TF_RETURN_IF_ERROR(points_to.ForEachElement([this](
+ const ShapeIndex& index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& buffers) {
+ for (auto buffer : buffers) {
+ // pointee is the HloInstruction producing the buffer which may be
+ // live out.
+ HloInstruction* pointee = buffer->instruction();
+ if (pointee->opcode() == HloOpcode::kParameter ||
+ pointee->opcode() == HloOpcode::kConstant) {
+ VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
+ << " index: " << tensorflow::str_util::Join(index, ",")
+ << " may be live out of computation: " << pointee->ToString();
+ RecordIndex(index);
+ }
+ }
+ return tensorflow::Status::OK();
+ }));
+ return tensorflow::Status::OK();
+}
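+
+// Illustrative example (hypothetical instructions, not from this module): if
+// 'instruction_' is Tuple(param, constant, add), leaf indices {0} and {1}
+// point to a parameter and a constant and are recorded for copying, while
+// index {2} points to a freshly computed buffer and is left uncopied.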
+
+tensorflow::Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers(
+ BufferLiveness* liveness, HloInstruction* other_instruction) {
+ TF_RETURN_IF_ERROR(
+ RecordAmbiguousOrNonDistinctIndices(liveness->points_to_analysis()));
+ TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction(
+ liveness, other_instruction));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+ // TODO(b/32116879) Use ShapeIndex here when it is available.
+ std::unordered_map<const LogicalBuffer*, std::vector<ShapeIndex>>
+ buffer_to_source_indices;
+ TF_RETURN_IF_ERROR(points_to.ForEachElement([this, &buffer_to_source_indices](
+ const ShapeIndex& index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& buffers) {
+ if (buffers.size() > 1) {
+ // Record ambiguous points-to set at 'index'.
+ if (!indices_to_copy_.element(index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " with ambiguous points-to set.";
+ RecordIndex(index);
+ }
+ }
+ // For each 'buffer': record a mapping from 'buffer' to 'index'.
+ for (auto& buffer : buffers) {
+ auto it = buffer_to_source_indices.find(buffer);
+ if (it == buffer_to_source_indices.end()) {
+ buffer_to_source_indices.insert({buffer, std::vector<ShapeIndex>()});
+ }
+ buffer_to_source_indices[buffer].push_back(index);
+ }
+ return tensorflow::Status::OK();
+ }));
+
+ // Record all non-distinct indices detected in 'buffer_to_source_indices'.
+ for (auto& buff_to_src : buffer_to_source_indices) {
+ if (buff_to_src.second.size() == 1) {
+ continue;
+ }
+ for (auto& src_index : buff_to_src.second) {
+ // Record non-distinct points-to set at 'src_index'.
+ if (!indices_to_copy_.element(src_index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(src_index, ",")
+ << " because of non-distinct points-to set.";
+ RecordIndex(src_index);
+ }
+ }
+ }
+ return tensorflow::Status::OK();
+}
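+
+// For intuition (illustrative only): a points-to set is ambiguous at an index
+// when more than one LogicalBuffer may live there (e.g. the result of a
+// kSelect between two tuples), and non-distinct when the same LogicalBuffer
+// appears at more than one index (e.g. Tuple(x, x)). Both cases are resolved
+// by copying the affected elements.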
+
+tensorflow::Status
+InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
+ BufferLiveness* liveness, HloInstruction* other_instruction) {
+ // Record all buffer indices for 'instruction_', which interfere with
+ // 'other_instruction' at the same index.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape(
+ instruction_->shape(),
+ [this, &liveness, &other_instruction](const Shape& /*subshape*/,
+ const ShapeIndex& index) {
+ if (indices_to_copy_.element(index)) {
+ // Return if previous pass already set index.
+ return tensorflow::Status::OK();
+ }
+ auto& points_to_analysis = liveness->points_to_analysis();
+ // Lookup buffers for 'instruction_' and 'other_instruction'.
+ const std::vector<const LogicalBuffer*> instruction_buffers =
+ points_to_analysis.GetPointsToSet(instruction_).element(index);
+ // If 'instruction_' has ambiguous points-to-set at 'index', it would
+ // have been recorded in a previous pass (and we would have returned
+ // early at the entry to this function). As a result, here we know that
+ // 'instruction_' has just one buffer in its points-to-set.
+ CHECK_EQ(1, instruction_buffers.size());
+ const LogicalBuffer* instruction_buffer = instruction_buffers[0];
+
+ const std::vector<const LogicalBuffer*> other_instruction_buffers =
+ points_to_analysis.GetPointsToSet(other_instruction).element(index);
+ // Do not insert a copy if both instructions point at the same buffer.
+ // This eliminates unnecessary copies of read-only tuple elements.
+ // If 'instruction_' and 'other_instruction' point to the same buffer,
+ // then that buffer is not updated on the path between the two
+ // instructions. Therefore, any other (possibly interference-causing)
+ // users of that buffer from 'other_instruction' will see the same data,
+ // irrespective of whether we insert a copy of this buffer at
+ // 'instruction_' or not.
+ if (other_instruction_buffers.size() == 1 &&
+ other_instruction_buffers[0]->id() == instruction_buffer->id()) {
+ return tensorflow::Status::OK();
+ }
+ // We can't say anything about the ambiguity of 'other_instruction' at
+ // this point, so we need to check interference between the single
+ // buffer in the points-to set of 'instruction_' and all buffers in
+ // 'other_instruction_buffers'.
+ for (auto& other_buffer : other_instruction_buffers) {
+ if (liveness->MayInterfere(*instruction_buffer, *other_buffer)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " because of interference with buffer: "
+ << other_buffer->ToString();
+ RecordIndex(index);
+ break;
+ }
+ }
+ return tensorflow::Status::OK();
+ }));
+ return tensorflow::Status::OK();
+}
+
+// Recursively inserts copies of 'instruction' tuple element buffers at
+// indices in 'indices_to_copy_', expanding tuples as needed.
+// TODO(b/31159897) Remove superfluous Tuple->GTE->Tuple expressions.
+HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction,
+ ShapeIndex* index) {
+ std::vector<HloInstruction*> element_copies;
+ const int64 num_tuple_elements =
+ ShapeUtil::TupleElementCount(instruction->shape());
+ for (int64 i = 0; i < num_tuple_elements; ++i) {
+ HloInstruction* gte = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i));
+ HloInstruction* element_copy;
+ index->push_back(i);
+ if (ShapeUtil::IsTuple(gte->shape())) {
+ element_copy = CopyTuple(gte, index);
+ } else {
+ if (indices_to_copy_.element(*index)) {
+ element_copy = gte->parent()->AddInstruction(
+ HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte));
+ } else {
+ element_copy = gte;
+ }
+ }
+ index->pop_back();
+ element_copies.push_back(element_copy);
+ }
+ return instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(element_copies));
+}
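+
+// Illustrative example (assumed shapes, not taken from this pass): for an
+// instruction of tuple shape ((F32[], F32[8]), F32[4]) with indices_to_copy_
+// set only at {0,1}, CopyTuple emits get-tuple-elements for every element,
+// wraps only the element at index {0,1} in a kCopy, and rebuilds the nested
+// tuples around the copied and uncopied elements.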
+
+// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'.
+HloInstruction* InstructionCopier::Copy() {
+ ShapeIndex index;
+ HloInstruction* copy;
+ if (ShapeUtil::IsTuple(instruction_->shape())) {
+ copy = CopyTuple(instruction_, &index);
+ } else {
+ copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction_->shape(), HloOpcode::kCopy, instruction_));
+ }
+ for (HloInstruction* user : copy_users_) {
+ VLOG(2) << "Adding copy between instruction: " << instruction_->name()
+ << " and user: " << user->name();
+ instruction_->ReplaceUseWith(user, copy);
+ }
+ return copy;
+}
+
+} // anonymous namespace
+
+StatusOr<HloInstruction*> CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) {
+ auto copy_it = inserted_copies_.find(hlo);
+ if (copy_it == inserted_copies_.end()) {
+ HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie();
+ inserted_copies_.insert({hlo, copy});
+ return copy;
+ } else {
+ return copy_it->second;
+ }
+}
+
+StatusOr<bool> CopyInsertion::Run(HloModule* module) {
+ bool changed = false;
+ VLOG(2) << "CopyInsertion for module " << module->name();
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferLiveness> liveness,
+ BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module)));
+ auto& points_to_analysis = liveness->points_to_analysis();
+ XLA_VLOG_LINES(2, points_to_analysis.ToString());
+ XLA_VLOG_LINES(2, module->ToString());
+
+ // Gather references to all while body computations in 'module'.
+ std::unordered_set<const HloComputation*> while_body_computations;
+ // Gather references to all while instructions in 'module' by computation.
+ std::unordered_map<const HloComputation*, std::vector<HloInstruction*>>
+ while_instructions;
+ for (auto& computation : module->computations()) {
+ for (auto& instruction : computation->instructions()) {
+ if (instruction->opcode() != HloOpcode::kWhile) {
+ continue;
+ }
+ while_body_computations.insert(instruction->while_body());
+ auto it = while_instructions.find(computation.get());
+ if (it == while_instructions.end()) {
+ while_instructions.insert(
+ {computation.get(), std::vector<HloInstruction*>()});
+ }
+ while_instructions[computation.get()].emplace_back(instruction.get());
+ }
+ }
+
+ for (auto& computation : module->computations()) {
+ VLOG(2) << "computation " << computation->name();
+
+ // Collect instruction buffer indices to copy in 'instructions_to_copy'.
+ std::vector<InstructionCopier> instructions_to_copy;
+
+ // Add copies of while 'init' operand instructions (if needed).
+ // TODO(b/33301720) Remove redundant while instruction copies.
+ auto it = while_instructions.find(computation.get());
+ if (it != while_instructions.end()) {
+ for (auto& while_hlo : it->second) {
+ // Create InstructionCopier for init operand of while instruction.
+ HloInstruction* init_hlo = while_hlo->mutable_operand(0);
+ instructions_to_copy.push_back(
+ InstructionCopier(/*init_value=*/false, init_hlo, {while_hlo}));
+ InstructionCopier& init_copier = instructions_to_copy.back();
+ // Record 'init' buffer indices which point-to a Constant or Parameter.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant(
+ liveness->points_to_analysis()));
+ // Record indices necessary to colocate while and init operand buffers.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers(
+ liveness.get(), while_hlo));
+ }
+ }
+
+ // Create InstructionCopier for computation root instruction.
+ instructions_to_copy.push_back(InstructionCopier(
+ /*init_value=*/false, computation->root_instruction(), {}));
+ InstructionCopier& root_copier = instructions_to_copy.back();
+
+ if (while_body_computations.count(computation.get()) > 0) {
+ // Record root indices to copy for while body sub-computations.
+ // We do not need to call RecordIndicesWhichPointToParamOrConstant for
+ // the while root instruction here, because any necessary copies needed
+ // to avoid constants or parameters in the output are handled by while.init
+ // operand copy insertion above (which will share an allocation).
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
+ liveness.get(), computation->parameter_instruction(0)));
+ } else {
+ // Record root indices to copy for general computations.
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
+ liveness->points_to_analysis()));
+ }
+
+ for (auto& to_copy : instructions_to_copy) {
+ if (to_copy.HasAllIndicesFalse()) {
+ continue;
+ }
+ changed = true;
+
+ // Copy instruction at recorded buffer indices.
+ HloInstruction* copy = to_copy.Copy();
+ if (to_copy.instruction() == computation->root_instruction()) {
+ computation->set_root_instruction(copy);
+ }
+ }
+ }
+
+ VLOG(3) << "After copy insertion for module " << module->name();
+ XLA_VLOG_LINES(3, module->ToString());
+
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/copy_insertion.h b/tensorflow/compiler/xla/service/copy_insertion.h
new file mode 100644
index 0000000000..4ea393ba94
--- /dev/null
+++ b/tensorflow/compiler/xla/service/copy_insertion.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// HLO pass which inserts a copy of the root instruction (creating a new root)
+// if the root is or points-to any constant or parameter instruction.
+// If the root instruction is a Tuple, only tuple elements which point to
+// constant or parameter instructions will be copied.
+// Copy insertion is necessary because constant and parameter arrays have
+// different lifetimes than computation results.
+class CopyInsertion : public HloPass {
+ public:
+ CopyInsertion() : HloPass("copy-insertion") {}
+ ~CopyInsertion() override {}
+
+ // Run the pass on the given module. Returns whether the module was changed
+ // (copies were inserted).
+ StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making
+ // duplicate copies.
+ StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
+
+ // A map containing all copies inserted during the copy insertion pass. The
+ // key is the copied instruction and the value is the copy.
+ std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;
+};
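+
+// Example usage (an illustrative sketch, not prescribed by this header;
+// assumes an HloModule* 'module' and the TF_ASSIGN_OR_RETURN macro are
+// available in the calling code):
+//
+//   CopyInsertion copy_insertion;
+//   TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module));
+//   // 'changed' is true iff any copies were inserted.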
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
new file mode 100644
index 0000000000..e64da58dc7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -0,0 +1,1153 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+
+#include <set>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+class CopyInsertionTest : public HloTestBase {
+ protected:
+ void InsertCopies(HloModule* module) {
+ CopyInsertion copy_insertion;
+ EXPECT_IS_OK(copy_insertion.Run(module).status());
+
+ // Verify that the points-to set of the root of the computation after copy
+ // insertion contains no constants or parameters.
+ auto points_to_analysis =
+ TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+ const std::set<const LogicalBuffer*> maybe_live_out_buffers =
+ points_to_analysis
+ ->GetPointsToSet(module->entry_computation()->root_instruction())
+ .CreateFlattenedSet();
+ for (const LogicalBuffer* buffer : maybe_live_out_buffers) {
+ EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kConstant);
+ EXPECT_NE(buffer->instruction()->opcode(), HloOpcode::kParameter);
+ }
+ }
+
+ // OperandTree is a test helper class that simplifies the expression of
+ // an expected tree of operands (starting at some root instruction) in a
+ // unit test.
+ // Each HLO instruction is represented as a node in the OperandTree.
+ struct OperandTree {
+ // The expected opcode for this OperandTree node.
+ HloOpcode opcode;
+ // The set of operands expected for this OperandTree node.
+ std::vector<OperandTree> operands;
+ // If non-null, a pointer to the expected HloInstruction at this node.
+ const HloInstruction* instruction = nullptr;
+
+ // Returns a mutable reference to operand 'i' of this node.
+ OperandTree& op(int i) {
+ if (i >= operands.size()) {
+ operands.resize(i + 1);
+ }
+ return operands[i];
+ }
+
+ // Check that 'instruction' and its operands match expected values recorded
+ // in OperandTree.
+ void Check(const HloInstruction* instruction) {
+ EXPECT_EQ(opcode, instruction->opcode());
+ if (this->instruction != nullptr) {
+ EXPECT_EQ(this->instruction, instruction);
+ }
+ if (operands.empty()) {
+ return;
+ }
+ EXPECT_EQ(operands.size(), instruction->operand_count());
+ for (int i = 0; i < instruction->operand_count(); ++i) {
+ operands[i].Check(instruction->operand(i));
+ }
+ }
+ };
+};
+
+#define EXPECT_INST(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
+
+TEST_F(CopyInsertionTest, SingleParameter) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({x}));
+
+ EXPECT_INST(x->users(), tuple);
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, SingleConstant) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({constant}));
+
+ EXPECT_INST(constant->users(), tuple);
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) {
+ // Create a computation with more than one constant and parameter. Only one of
+ // each constant/parameter is pointed to by the output tuple. Only these
+ // instructions should be copied.
+ auto builder = HloComputation::Builder(TestName());
+
+ HloInstruction* constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+
+ HloInstruction* x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x"));
+ HloInstruction* y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "y"));
+
+ HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, constant1, y));
+
+ builder.AddInstruction(HloInstruction::CreateTuple({constant2, x, add}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // "constant2" and parameter "x" are pointed to by the tuple and should be
+ // copied.
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).instruction = old_root;
+
+ op_tree.op(2).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(2).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(2).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, AmbiguousPointsToSet) {
+ // Create a computation using select which has an ambiguous points-to set for
+ // the computation result. Verify that copies are added properly.
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ HloInstruction* constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+
+ HloInstruction* tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ HloInstruction* tuple2 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant3, constant2}));
+
+ HloInstruction* pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+
+ EXPECT_INST(constant1->users(), tuple1);
+ EXPECT_INST(constant2->users(), tuple1, tuple2);
+ EXPECT_INST(constant3->users(), tuple2);
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kSelect;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kSelect;
+ op_tree.op(1).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, BitcastParameter) {
+ // The output of a bitcast is its operand (same buffer), so a bitcast
+ // parameter feeding the result must have a copy added.
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
+ HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_INST(x->users(), bitcast);
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kCopy;
+ op_tree.op(0).opcode = HloOpcode::kBitcast;
+ op_tree.op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, BitcastConstant) {
+ // The output of a bitcast is its operand (same buffer), so a bitcast
+ // constant feeding the result must have a copy added.
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 42.0})));
+ HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_INST(constant->users(), bitcast);
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kCopy;
+ op_tree.op(0).opcode = HloOpcode::kBitcast;
+ op_tree.op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, BitcastTupleElementParameter) {
+ // Same as BitcastParameter, but the bitcast is wrapped in a tuple.
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {4}), "x"));
+ HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, x));
+ builder.AddInstruction(HloInstruction::CreateTuple({bitcast}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(1, x->user_count());
+ EXPECT_EQ(*x->users().begin(), bitcast);
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, NestedTupleParameter) {
+ // Construct a trivial computation where the root of the computation is a
+ // nested tuple-shaped parameter. The parameter should be deep copied and the
+ // copy should be the root of the computation.
+ auto builder = HloComputation::Builder(TestName());
+
+ // Param shape is: ((F32[], S32[1,2,3]), F32[42])
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeShape(S32, {1, 2, 3})}),
+ ShapeUtil::MakeShape(F32, {42})}),
+ "param0"));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(HloOpcode::kParameter,
+ module.entry_computation()->root_instruction()->opcode());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+ EXPECT_NE(old_root, new_root);
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).op(0).op(0).opcode = HloOpcode::kParameter;
+ op_tree.op(0).op(0).op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(1).op(0).op(0).op(0).opcode = HloOpcode::kParameter;
+ op_tree.op(0).op(1).op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kParameter;
+ op_tree.op(1).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) {
+ // Construct a computation where the root of the computation is a tuple
+ // element of a nested tuple-shaped parameter.
+ auto builder = HloComputation::Builder(TestName());
+
+ // Param shape is: ((F32[], S32[1,2,3]), F32[42])
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {}),
+ ShapeUtil::MakeShape(S32, {1, 2, 3})}),
+ ShapeUtil::MakeShape(F32, {42})}),
+ "param0"));
+
+ // The return value of the computation is the zero-th element of the nested
+ // tuple. This element is itself a tuple.
+ auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(param->shape(), {0}), param, 0));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) {
+ // Create a computation using select which has an ambiguous points-to set for
+ // the top-level buffer of the root of the computation. Verify that a shallow
+ // copy is added.
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+
+ HloInstruction* tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ HloInstruction* tuple2 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant2, constant1}));
+
+ HloInstruction* pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+ HloInstruction* gte =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(select->shape(), {0}), select, 0));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(gte, module.entry_computation()->root_instruction());
+
+ HloInstruction* old_root = module.entry_computation()->root_instruction();
+ InsertCopies(&module);
+ HloInstruction* new_root = module.entry_computation()->root_instruction();
+
+ // Check path from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kCopy;
+ op_tree.op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+class WhileCopyInsertionTest : public CopyInsertionTest {
+ protected:
+ WhileCopyInsertionTest() : module_(TestName()) {}
+
+ // Builds a While condition computation which reads the induction variable
+ // from the tuple parameter, and returns a predicate indicating whether this
+ // value is less than the constant '10'.
+ // The parameter 'nested' specifies the loop state shape from which to
+ // read the induction variable.
+ std::unique_ptr<HloComputation> BuildConditionComputation(
+ bool nested = false) {
+ auto builder = HloComputation::Builder(TestName() + ".Condition");
+ auto limit_const = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
+ const Shape& loop_state_shape =
+ nested ? nested_loop_state_shape_ : loop_state_shape_;
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ limit_const->shape(), loop_state, 0));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
+ induction_variable, limit_const));
+ return builder.Build();
+ }
+
+ // Builds a While body computation with one output tuple element dependent on
+ // both input tuple elements.
+ // EX:
+ // Body({in0, in1})
+ // out0 = Add(in0, 1)
+ // out1 = Add(BCast(in0), in1)
+ // Tuple(out0, out1)
+ std::unique_ptr<HloComputation> BuildDependentBodyComputation() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ // Create param instruction to access loop state.
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ // Update the induction variable GTE(0).
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, loop_state, 0));
+ auto inc = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
+ // Update data GTE(1).
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+ // Broadcast 'induction_variable' to the data shape so it can be added to the data element below.
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, data, update));
+ // Create output Tuple.
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+ return builder.Build();
+ }
+
+ // Builds a While body computation with read-only tuple element 0, where
+ // output tuple element 1 depends on both input tuple elements.
+ // EX:
+ // Body({in0, in1})
+ // out0 = in0
+ // out1 = Add(BCast(in0), in1)
+ // Tuple(out0, out1)
+ std::unique_ptr<HloComputation> BuildDependentBodyOneReadOnlyComputation() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ // Create param instruction to access loop state.
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ // Update the induction variable GTE(0).
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, loop_state, 0));
+ // Update data GTE(1).
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+ // Broadcast 'induction_variable' to the data shape so it can be added to the data element below.
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, data, update));
+ // Create output Tuple.
+ builder.AddInstruction(
+ HloInstruction::CreateTuple({induction_variable, add1}));
+ return builder.Build();
+ }
+
+ // Builds a While body computation with independent outputs.
+ // EX:
+ // Body({in0, in1})
+ // out0 = Add(in0, 1)
+ // out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
+ // Tuple(out0, out1)
+ std::unique_ptr<HloComputation> BuildIndependentBodyComputation() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ // Create param instruction to access loop state.
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ // Update the induction variable GTE(0).
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, loop_state, 0));
+ auto inc = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+ // add0 = Add(in0, 1)
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
+ // Update data GTE(1).
+ auto data = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1));
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, data, update));
+ // Create output Tuple.
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+ return builder.Build();
+ }
+
+ // Builds a While body computation with the following nested tuple
+ // sub-computation:
+ // |
+ // GTE(loop_state, 1)
+ // / \
+ // GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
+ // | |
+ // Add Reverse
+ // | |
+ std::unique_ptr<HloComputation> BuildNestedBodyComputation() {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ // Create param instruction to access loop state.
+ auto loop_state = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, nested_loop_state_shape_, "loop_state"));
+ // Update GTE(0).
+ auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, loop_state, 0));
+ auto inc = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ gte0->shape(), HloOpcode::kAdd, gte0, inc));
+
+ // GTE(loop_state, 1)
+ auto gte1 = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ nested_tuple_shape_, loop_state, 1));
+ // GTE(GTE(loop_state, 1), 0) -> Add
+ auto gte10 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0));
+ auto update10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ auto add10 = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, gte10, update10));
+
+ // GTE(GTE(loop_state, 1), 1) -> Reverse
+ auto gte11 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, gte1, 1));
+ auto rev11 = builder.AddInstruction(
+ HloInstruction::CreateReverse(data_shape_, gte11, {0}));
+
+ // Create output Tuple.
+ auto inner_tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add10, rev11}));
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, inner_tuple}));
+ return builder.Build();
+ }
+
+ // Builds a While instruction using 'condition' and 'body' sub-computations.
+ // Init operand is initialized to zeros of appropriate shape.
+ void BuildWhileInstruction(HloComputation* condition, HloComputation* body,
+ bool nested = false) {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+ auto induction_var_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+
+ if (nested) {
+ auto inner_init = builder.AddInstruction(
+ HloInstruction::CreateTuple({data_init, data_init}));
+ auto loop_state_init = builder.AddInstruction(
+ HloInstruction::CreateTuple({induction_var_init, inner_init}));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_state_shape_, condition, body, loop_state_init));
+ module_.AddEntryComputation(builder.Build());
+ return;
+ }
+
+ auto loop_state_init = builder.AddInstruction(
+ HloInstruction::CreateTuple({induction_var_init, data_init}));
+ builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_state_shape_, condition, body, loop_state_init));
+ module_.AddEntryComputation(builder.Build());
+ }
+
+ HloInstruction* BuildWhileInstruction_InitPointsToConstant() {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
+ &builder);
+ }
+
+ HloInstruction* BuildWhileInstruction_InitPointsToParameter() {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape_, "data_init"));
+ return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
+ &builder);
+ }
+
+ HloInstruction* BuildWhileInstruction_InitPointsToAmbiguous() {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto v1 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, one, {1}));
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
+ auto v2 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+ auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({v1, v2}));
+ auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1}));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ auto data_init = builder.AddInstruction(HloInstruction::CreateTernary(
+ nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2));
+
+ return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
+ data_init, &builder);
+ }
+
+ HloInstruction* BuildWhileInstruction_InitPointsToNonDistinct() {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto one_vec = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, one, {1}));
+ auto data_init =
+ builder.AddInstruction(HloInstruction::CreateTuple({one_vec, one_vec}));
+
+ return BuildWhileInstructionWithCustomInit(nested_loop_state_shape_,
+ data_init, &builder);
+ }
+
+ HloInstruction* BuildWhileInstruction_InitPointsToInterfering() {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, one, {1}));
+ auto one_vec = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f})));
+ // Take a reference to 'data_init' to make it interfere with while result.
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, data_init, one_vec));
+
+ return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init,
+ &builder);
+ }
+
+ HloInstruction* BuildWhileInstructionWithCustomInit(
+ const Shape& loop_state_shape, HloInstruction* data_init,
+ HloComputation::Builder* builder) {
+ auto induction_var_init = builder->AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation());
+ auto body =
+ module_.AddEmbeddedComputation(BuildIndependentBodyComputation());
+ auto loop_state_init = builder->AddInstruction(
+ HloInstruction::CreateTuple({induction_var_init, data_init}));
+ auto while_hlo = builder->AddInstruction(HloInstruction::CreateWhile(
+ loop_state_shape, condition, body, loop_state_init));
+ module_.AddEntryComputation(builder->Build());
+ return while_hlo;
+ }
+
+ HloModule module_;
+ Shape induction_variable_shape_ = ShapeUtil::MakeShape(S32, {});
+ Shape data_shape_ = ShapeUtil::MakeShape(F32, {8});
+ Shape loop_state_shape_ =
+ ShapeUtil::MakeTupleShape({induction_variable_shape_, data_shape_});
+ Shape nested_tuple_shape_ =
+ ShapeUtil::MakeTupleShape({data_shape_, data_shape_});
+ Shape nested_loop_state_shape_ = ShapeUtil::MakeTupleShape(
+ {induction_variable_shape_, nested_tuple_shape_});
+ Shape condition_result_shape_ = ShapeUtil::MakeShape(PRED, {});
+};
+
+// Tests while body computation with independent tuple elements:
+//
+// While.Body({in0, in1})
+// out0 = Add(in0, 1)
+// out1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1})
+// Tuple(out0, out1)
+//
+// CopyInsertion pass should not generate any copies.
+//
+TEST_F(WhileCopyInsertionTest, IndependentTupleElements) {
+ auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
+ auto body = module_.AddEmbeddedComputation(BuildIndependentBodyComputation());
+ BuildWhileInstruction(condition, body);
+
+ HloInstruction* old_root = body->root_instruction();
+ InsertCopies(&module_);
+ HloInstruction* new_root = body->root_instruction();
+
+ // No copies should be inserted so root should not be updated.
+ CHECK_EQ(old_root, new_root);
+}
+
+// Tests while body computation with dependent tuple elements:
+//
+// While.Body({in0, in1})
+// out0 = Add(in0, 1)
+// out1 = Add(BCast(in0), in1)
+// Tuple(out0, out1)
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old root
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy |
+// \ /
+// Tuple // new root
+//
+TEST_F(WhileCopyInsertionTest, DependentTupleElements) {
+ auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
+ auto body = module_.AddEmbeddedComputation(BuildDependentBodyComputation());
+ BuildWhileInstruction(condition, body);
+
+ HloInstruction* old_root = body->root_instruction();
+ InsertCopies(&module_);
+ HloInstruction* new_root = body->root_instruction();
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+// Tests while body computation with read-only tuple element 0:
+//
+// PARAMETER
+// / \
+// GTE(0) GTE(1)
+// | \ |
+// | BCAST |
+// | \ |
+// | ADD
+// | |
+// \ /
+// TUPLE (root)
+//
+// CopyInsertion pass should not generate any copies.
+//
+TEST_F(WhileCopyInsertionTest, DependentTupleElements_OneReadOnly) {
+ auto condition = module_.AddEmbeddedComputation(BuildConditionComputation());
+ auto body = module_.AddEmbeddedComputation(
+ BuildDependentBodyOneReadOnlyComputation());
+ BuildWhileInstruction(condition, body);
+
+ HloInstruction* old_root = body->root_instruction();
+ InsertCopies(&module_);
+ HloInstruction* new_root = body->root_instruction();
+
+ // No copies should be inserted so root should not be updated.
+ CHECK_EQ(old_root, new_root);
+}
+
+// Tests while body computation with nested tuple elements:
+//
+// |
+// GTE(loop_state, 1)
+// / \
+// GTE(GTE(loop_state, 1), 0) GTE(GTE(loop_state, 1), 1)
+// | |
+// Add Reverse
+// | |
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old root
+// / \
+// / \
+// GTE(0) GTE(1)
+// | / \
+// | / \
+// | GTE(0) GTE(1)
+// | | |
+// | | Copy
+// | | |
+// \ | /
+// \ Tuple // "inner" tuple.
+// \ /
+// \ /
+// Tuple // new root
+//
+TEST_F(WhileCopyInsertionTest, NestedTupleElements) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(true));
+ auto body = module_.AddEmbeddedComputation(BuildNestedBodyComputation());
+ BuildWhileInstruction(condition, body, true);
+
+ HloInstruction* old_root = body->root_instruction();
+ InsertCopies(&module_);
+ HloInstruction* new_root = body->root_instruction();
+
+ // Check all paths from 'new_root' to 'old_root'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).opcode = HloOpcode::kTuple;
+
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.op(1).op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_root;
+
+ op_tree.Check(new_root);
+}
+
+// Tests while init instruction which points-to a constant.
+//
+// init = Tuple(Constant(S32, {}), Constant(F32, {8}))
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old init
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy Copy
+// \ /
+// Tuple // new init
+//
+TEST_F(WhileCopyInsertionTest, InitPointsToConstant) {
+ auto while_hlo = BuildWhileInstruction_InitPointsToConstant();
+ auto old_init = while_hlo->operand(0);
+ InsertCopies(&module_);
+ auto new_init = while_hlo->operand(0);
+
+ // Check all paths from 'new_init' to 'old_init'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).instruction = old_init;
+
+ op_tree.Check(new_init);
+}
+
+// Tests while init instruction which points-to a parameter.
+//
+// init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old init
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy Copy
+// \ /
+// Tuple // new init
+//
+TEST_F(WhileCopyInsertionTest, InitPointsToParameter) {
+ auto while_hlo = BuildWhileInstruction_InitPointsToParameter();
+ auto old_init = while_hlo->operand(0);
+ InsertCopies(&module_);
+ auto new_init = while_hlo->operand(0);
+
+ // Check all paths from 'new_init' to 'old_init'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).instruction = old_init;
+
+ op_tree.Check(new_init);
+}
+
+// Tests while init instruction which has an ambiguous points-to set.
+//
+// select = Select(pred, tuple1, tuple2)
+// init = Tuple(Constant(S32, {}), Parameter(F32, {8}))
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old init
+// / \
+// / \
+// GTE(0) GTE(1)
+// | / \
+// | / \
+// | GTE(0) GTE(1)
+// | | |
+// Copy Copy Copy
+// | | |
+// \ | /
+// \ Tuple
+// \ /
+// \ /
+// Tuple // new init
+//
+TEST_F(WhileCopyInsertionTest, InitPointsToAmbiguous) {
+ auto while_hlo = BuildWhileInstruction_InitPointsToAmbiguous();
+ auto old_init = while_hlo->operand(0);
+ InsertCopies(&module_);
+ auto new_init = while_hlo->operand(0);
+
+ // Check all paths from 'new_init' to 'old_init'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).opcode = HloOpcode::kTuple;
+
+ op_tree.op(1).op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.Check(new_init);
+}
+
+// Tests while init instruction which has a non-distinct points-to set.
+//
+// init = Tuple(Constant(S32, {}), Tuple({vec_one, vec_one}))
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old init
+// / \
+// / \
+// GTE(0) GTE(1)
+// | / \
+// | / \
+// | GTE(0) GTE(1)
+// | | |
+// Copy Copy Copy
+// | | |
+// \ | /
+// \ Tuple
+// \ /
+// \ /
+// Tuple // new init
+//
+TEST_F(WhileCopyInsertionTest, InitPointsToNonDistinct) {
+ auto while_hlo = BuildWhileInstruction_InitPointsToNonDistinct();
+ auto old_init = while_hlo->operand(0);
+ InsertCopies(&module_);
+ auto new_init = while_hlo->operand(0);
+
+ // Check all paths from 'new_init' to 'old_init'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).opcode = HloOpcode::kTuple;
+
+ op_tree.op(1).op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(1).op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(1).op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.Check(new_init);
+}
+
+// Tests a while init instruction buffer which interferes with the while
+// result buffer.
+//
+// init_data = Broadcast(...)
+// add_unrelated = Add(init_data) // takes a reference to cause interference
+// init = Tuple(Constant(S32, {}), init_data)
+//
+// CopyInsertion pass should generate:
+//
+// Tuple // old init
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy Copy
+// \ /
+// Tuple // new init
+//
+TEST_F(WhileCopyInsertionTest, InitPointsToInterfering) {
+ auto while_hlo = BuildWhileInstruction_InitPointsToInterfering();
+ auto old_init = while_hlo->operand(0);
+ InsertCopies(&module_);
+ auto new_init = while_hlo->operand(0);
+
+ // Check all paths from 'new_init' to 'old_init'.
+ OperandTree op_tree;
+ op_tree.opcode = HloOpcode::kTuple;
+
+ op_tree.op(0).opcode = HloOpcode::kCopy;
+ op_tree.op(0).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(0).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(0).op(0).op(0).instruction = old_init;
+
+ op_tree.op(1).opcode = HloOpcode::kCopy;
+ op_tree.op(1).op(0).opcode = HloOpcode::kGetTupleElement;
+ op_tree.op(1).op(0).op(0).opcode = HloOpcode::kTuple;
+ op_tree.op(1).op(0).op(0).instruction = old_init;
+
+ op_tree.Check(new_init);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
new file mode 100644
index 0000000000..8af54b11bb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -0,0 +1,529 @@
+# Description:
+# LLVM-based CPU backend for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [":friends"],
+ features = [
+ "-layering_check",
+ "-parse_headers",
+ ],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+load(":build_defs.bzl", "runtime_copts")
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "cpu_compiler",
+ srcs = ["cpu_compiler.cc"],
+ hdrs = ["cpu_compiler.h"],
+ deps = [
+ ":compiler_functor",
+ ":conv_canonicalization",
+ ":cpu_executable",
+ ":cpu_instruction_fusion",
+ ":cpu_parallelization_preparation",
+ ":disassembler",
+ ":ir_emission_utils",
+ ":ir_emitter",
+ ":layout_assignment",
+ ":parallel_cpu_executable",
+ ":simple_orc_jit",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/port:initialize",
+ "//tensorflow/compiler/xla/service:algebraic_simplifier",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:buffer_liveness",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:copy_insertion",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_cse",
+ "//tensorflow/compiler/xla/service:hlo_dce",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+ "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
+ "//tensorflow/compiler/xla/service:inliner",
+ "//tensorflow/compiler/xla/service:reshape_mover",
+ "//tensorflow/compiler/xla/service:transpose_folding",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep
+ "//tensorflow/core:lib", # fixdeps: keep
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:aarch64_code_gen", # fixdeps: keep
+ "@llvm//:aarch64_disassembler", # fixdeps: keep
+ "@llvm//:arm_code_gen", # fixdeps: keep
+ "@llvm//:arm_disassembler", # fixdeps: keep
+ "@llvm//:core",
+ "@llvm//:mc", # fixdeps: keep
+ "@llvm//:object",
+ "@llvm//:powerpc_code_gen", # fixdeps: keep
+ "@llvm//:powerpc_disassembler", # fixdeps: keep
+ "@llvm//:support",
+ "@llvm//:target", # fixdeps: keep
+ "@llvm//:x86_code_gen", # fixdeps: keep
+ "@llvm//:x86_disassembler", # fixdeps: keep
+ ],
+ alwayslink = True, # Contains compiler registration
+)
+
+cc_library(
+ name = "simple_orc_jit",
+ srcs = ["simple_orc_jit.cc"],
+ hdrs = ["simple_orc_jit.h"],
+ deps = [
+ ":compiler_functor",
+ ":cpu_runtime",
+ ":cpu_runtime_avx",
+ ":cpu_runtime_sse4_1",
+ ":disassembler",
+ ":runtime_conv2d",
+ ":runtime_matmul",
+ ":runtime_single_threaded_conv2d",
+ ":runtime_single_threaded_matmul",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ "@llvm//:mc", # fixdeps: keep
+ "@llvm//:orc_jit",
+ "@llvm//:support",
+ "@llvm//:target", # fixdeps: keep
+ ],
+)
+
+cc_library(
+ name = "cpu_executable",
+ srcs = ["cpu_executable.cc"],
+ hdrs = ["cpu_executable.h"],
+ deps = [
+ ":simple_orc_jit",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:orc_jit",
+ ],
+)
+
+cc_library(
+ name = "parallel_cpu_executable",
+ srcs = ["parallel_cpu_executable.cc"],
+ hdrs = ["parallel_cpu_executable.h"],
+ deps = [
+ ":cpu_runtime",
+ ":simple_orc_jit",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:orc_jit",
+ ],
+)
+
+cc_library(
+ name = "ir_emitter",
+ srcs = ["ir_emitter.cc"],
+ hdrs = ["ir_emitter.h"],
+ deps = [
+ ":cpu_runtime",
+ ":dot_op_emitter",
+ ":elemental_ir_emitter",
+ ":ir_emission_utils",
+ ":simple_orc_jit",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:name_uniquer",
+ "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:ops",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ "@llvm//:support",
+ ],
+)
+
+cc_library(
+ name = "dot_op_emitter",
+ srcs = ["dot_op_emitter.cc"],
+ hdrs = ["dot_op_emitter.h"],
+ deps = [
+ ":cpu_runtime",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_binary(
+ name = "sample_harness",
+ srcs = ["sample_harness.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "disassembler",
+ srcs = ["disassembler.cc"],
+ hdrs = ["disassembler.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ "@llvm//:mc",
+ "@llvm//:mc_disassembler",
+ "@llvm//:object",
+ "@llvm//:powerpc_disassembler", # fixdeps: keep
+ "@llvm//:support",
+ "@llvm//:target",
+ "@llvm//:x86_disassembler", # fixdeps: keep
+ ],
+)
+
+cc_library(
+ name = "compiler_functor",
+ srcs = ["compiler_functor.cc"],
+ hdrs = ["compiler_functor.h"],
+ deps = [
+ ":cpu_runtime",
+ ":cpu_runtime_avx",
+ ":cpu_runtime_sse4_1",
+ ":disassembler",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "@llvm//:analysis",
+ "@llvm//:core",
+ "@llvm//:ipo",
+ "@llvm//:mc",
+ "@llvm//:object",
+ "@llvm//:support",
+ "@llvm//:target",
+ ],
+)
+
+cc_library(
+ name = "cpu_runtime_sse4_1",
+ srcs = ["cpu_runtime_sse4_1.cc"],
+ hdrs = ["cpu_runtime_sse4_1.h"],
+ copts = ["-DEIGEN_AVOID_STL_ARRAY"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "cpu_runtime_avx",
+ srcs = ["cpu_runtime_avx.cc"],
+ hdrs = ["cpu_runtime_avx.h"],
+ copts = ["-DEIGEN_AVOID_STL_ARRAY"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+ alwayslink = True,
+)
+
+cc_library(
+ name = "cpu_runtime",
+ srcs = [
+ "cpu_runtime.cc",
+ "infeed_manager.cc",
+ ],
+ hdrs = [
+ "cpu_runtime.h",
+ "infeed_manager.h",
+ ],
+ copts = runtime_copts(),
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ ],
+)
+
+cc_library(
+ name = "runtime_conv2d",
+ srcs = [
+ "runtime_conv2d.cc",
+ "runtime_conv2d_impl.h",
+ ],
+ hdrs = ["runtime_conv2d.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core/kernels:eigen_helpers",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
+ name = "runtime_matmul",
+ srcs = ["runtime_matmul.cc"],
+ hdrs = ["runtime_matmul.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
+ name = "runtime_single_threaded_conv2d",
+ srcs = [
+ "runtime_conv2d_impl.h",
+ "runtime_single_threaded_conv2d.cc",
+ ],
+ hdrs = ["runtime_single_threaded_conv2d.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core/kernels:eigen_helpers",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_library(
+ name = "runtime_single_threaded_matmul",
+ srcs = ["runtime_single_threaded_matmul.cc"],
+ hdrs = ["runtime_single_threaded_matmul.h"],
+ copts = runtime_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_test(
+ name = "cpu_runtime_test",
+ srcs = ["cpu_runtime_test.cc"],
+ deps = [
+ ":cpu_runtime",
+ ":runtime_matmul",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_test(
+ name = "infeed_manager_test",
+ srcs = ["infeed_manager_test.cc"],
+ deps = [
+ ":cpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "cpu_instruction_fusion",
+ srcs = ["cpu_instruction_fusion.cc"],
+ hdrs = ["cpu_instruction_fusion.h"],
+ deps = [
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:instruction_fusion",
+ ],
+)
+
+cc_library(
+ name = "cpu_parallelization_preparation",
+ srcs = ["cpu_parallelization_preparation.cc"],
+ hdrs = ["cpu_parallelization_preparation.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "elemental_ir_emitter",
+ srcs = ["elemental_ir_emitter.cc"],
+ hdrs = ["elemental_ir_emitter.h"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "ir_emission_utils",
+ srcs = ["ir_emission_utils.cc"],
+ hdrs = ["ir_emission_utils.h"],
+ deps = [
+ ":cpu_runtime",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ ],
+)
+
+cc_library(
+ name = "layout_assignment",
+ srcs = ["layout_assignment.cc"],
+ hdrs = ["layout_assignment.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:layout_assignment",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "conv_canonicalization",
+ srcs = ["conv_canonicalization.cc"],
+ hdrs = ["conv_canonicalization.h"],
+ deps = [
+ ":cpu_runtime",
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ ],
+)
+
+cc_test(
+ name = "conv_canonicalization_test",
+ srcs = ["conv_canonicalization_test.cc"],
+ deps = [
+ ":conv_canonicalization",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/service/cpu/build_defs.bzl b/tensorflow/compiler/xla/service/cpu/build_defs.bzl
new file mode 100644
index 0000000000..b4b5219751
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/build_defs.bzl
@@ -0,0 +1,11 @@
+"""build_defs for service/cpu."""
+
+def runtime_copts():
+ """Returns copts used for CPU runtime libraries."""
+ return (["-DEIGEN_AVOID_STL_ARRAY"] +
+ select({
+ "//tensorflow:android_arm": ["-mfpu=neon"],
+ "//conditions:default": []}) +
+ select({
+ "//tensorflow:android": ["-O2"],
+ "//conditions:default": []}))
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
new file mode 100644
index 0000000000..89b3302bca
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
+
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/StringRef.h"
+#include "external/llvm/include/llvm/Analysis/TargetLibraryInfo.h"
+#include "external/llvm/include/llvm/Analysis/TargetTransformInfo.h"
+#include "external/llvm/include/llvm/ExecutionEngine/ObjectMemoryBuffer.h"
+#include "external/llvm/include/llvm/IR/LegacyPassManager.h"
+#include "external/llvm/include/llvm/IR/Verifier.h"
+#include "external/llvm/include/llvm/MC/MCContext.h"
+#include "external/llvm/include/llvm/Object/ObjectFile.h"
+#include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "external/llvm/include/llvm/Transforms/IPO.h"
+#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h"
+#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace cpu {
+
+/* static */ CompilerFunctor::VectorIntrinsics
+CompilerFunctor::AllIntrinsics() {
+ VectorIntrinsics intrinsics;
+ intrinsics.sse_intrinsics = true;
+ intrinsics.avx_intrinsics = true;
+ return intrinsics;
+}
+
+llvm::object::OwningBinary<llvm::object::ObjectFile> CompilerFunctor::
+operator()(llvm::Module& module) const {
+ llvm::legacy::PassManager module_passes;
+ llvm::legacy::FunctionPassManager function_passes(&module);
+
+ VLOG(2) << "IR before optimizations";
+ XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
+ legacy_flags::CompilerFunctorFlags* flags =
+ legacy_flags::GetCompilerFunctorFlags();
+ string dump_path = flags->xla_debug_cpu_dump_ir;
+ if (!dump_path.empty()) {
+ std::unique_ptr<tensorflow::WritableFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewAppendableFile(dump_path, &f));
+ TF_CHECK_OK(f->Append(llvm_ir::DumpModuleToString(module)));
+ TF_CHECK_OK(f->Close());
+ }
+
+ // Build up optimization pipeline.
+ AddOptimizationPasses(&module_passes, &function_passes);
+
+ // Run optimization passes on module.
+ function_passes.doInitialization();
+ for (auto func = module.begin(); func != module.end(); ++func) {
+ function_passes.run(*func);
+ }
+ function_passes.doFinalization();
+ module_passes.run(module);
+
+ // Buffer for holding machine code prior to constructing the ObjectFile.
+ llvm::SmallVector<char, 0> stream_buffer;
+ llvm::raw_svector_ostream ostream(stream_buffer);
+
+ VLOG(2) << "IR after optimizations";
+ XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(module));
+
+ // Generate code.
+ llvm::MCContext* mc_context;
+ llvm::legacy::PassManager codegen_passes;
+ target_machine_->addPassesToEmitMC(codegen_passes, mc_context, ostream);
+ codegen_passes.run(module);
+
+ // Construct ObjectFile from machine code buffer.
+ std::unique_ptr<llvm::MemoryBuffer> memory_buffer(
+ new llvm::ObjectMemoryBuffer(std::move(stream_buffer)));
+ llvm::Expected<std::unique_ptr<llvm::object::ObjectFile>>
+ object_file_or_error = llvm::object::ObjectFile::createObjectFile(
+ memory_buffer->getMemBufferRef());
+ CHECK(object_file_or_error);
+
+ std::unique_ptr<llvm::object::ObjectFile> object_file =
+ std::move(object_file_or_error.get());
+ if (VLOG_IS_ON(2)) {
+ StatusOr<string> disassembly_status =
+ disassembler_->DisassembleObjectFile(*object_file);
+ if (disassembly_status.ok()) {
+ XLA_VLOG_LINES(2, disassembly_status.ValueOrDie());
+ }
+ }
+
+ return llvm::object::OwningBinary<llvm::object::ObjectFile>(
+ std::move(object_file), std::move(memory_buffer));
+}
+
+namespace {
+// Returns the set of vectorized library functions supported for the target.
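+// For example, on an x86-64 target whose feature string contains both
+// "+sse4.1" and "+avx" (and with --xla_cpu_use_eigen enabled), both the
+// 4-wide and 8-wide entries below are returned, so a vectorized call to
+// tanhf can be lowered to runtime::kTanhV4F32 or runtime::kTanhV8F32.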
+std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
+ llvm::Triple::ArchType arch, llvm::StringRef feature_string,
+ CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
+ std::vector<llvm::VecDesc> vector_functions;
+
+ const llvm::VecDesc four_wide_vector_functions[] = {
+ {"expf", runtime::kExpV4F32, 4},
+ {"llvm.exp.f32", runtime::kExpV4F32, 4},
+
+ {"logf", runtime::kLogV4F32, 4},
+ {"llvm.log.f32", runtime::kLogV4F32, 4},
+
+ {"tanhf", runtime::kTanhV4F32, 4},
+ {"llvm.tanh.f32", runtime::kTanhV4F32, 4},
+ };
+
+ const llvm::VecDesc eight_wide_vector_functions[] = {
+ {"expf", runtime::kExpV8F32, 8},
+ {"llvm.exp.f32", runtime::kExpV8F32, 8},
+
+ {"logf", runtime::kLogV8F32, 8},
+ {"llvm.log.f32", runtime::kLogV8F32, 8},
+
+ {"tanhf", runtime::kTanhV8F32, 8},
+ {"llvm.tanh.f32", runtime::kTanhV8F32, 8},
+ };
+
+  // Our vectorized library calls are currently implemented by calling into
+  // Eigen. As such, only emit calls to these routines if --xla_cpu_use_eigen
+  // is enabled.
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ if (flags->xla_cpu_use_eigen &&
+      (arch == llvm::Triple::x86 || arch == llvm::Triple::x86_64)) {
+ llvm::SmallVector<llvm::StringRef, 32> features;
+ feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
+ if (std::find(features.begin(), features.end(), "+sse4.1") !=
+ features.end() &&
+ available_intrinsics.sse_intrinsics) {
+ vector_functions.insert(vector_functions.end(),
+ std::begin(four_wide_vector_functions),
+ std::end(four_wide_vector_functions));
+ }
+ if (std::find(features.begin(), features.end(), "+avx") != features.end() &&
+ available_intrinsics.avx_intrinsics) {
+ vector_functions.insert(vector_functions.end(),
+ std::begin(eight_wide_vector_functions),
+ std::end(eight_wide_vector_functions));
+ }
+ }
+ return vector_functions;
+}
+} // namespace
+
+void CompilerFunctor::AddOptimizationPasses(
+ llvm::legacy::PassManagerBase* module_passes,
+ llvm::legacy::FunctionPassManager* function_passes) const {
+ llvm::Triple target_triple(target_machine_->getTargetTriple());
+ auto target_library_info_impl =
+ MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
+ target_library_info_impl->addVectorizableFunctions(
+ VectorFunctionsForTargetLibraryInfoImpl(
+ target_triple.getArch(), target_machine_->getTargetFeatureString(),
+ available_intrinsics_));
+ module_passes->add(
+ new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
+ module_passes->add(createTargetTransformInfoWrapperPass(
+ target_machine_->getTargetIRAnalysis()));
+
+ module_passes->add(llvm::createVerifierPass());
+
+ llvm::PassManagerBuilder builder;
+ builder.OptLevel = opt_level_;
+ builder.SizeLevel = 0;
+
+ if (opt_level_ > 1) {
+ builder.Inliner = llvm::createFunctionInliningPass();
+ } else {
+ // Only inline functions marked with "alwaysinline".
+ builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
+ }
+
+ builder.DisableUnitAtATime = false;
+ builder.DisableUnrollLoops = opt_level_ == 0;
+ builder.LoopVectorize = opt_level_ > 0;
+ builder.SLPVectorize = opt_level_ > 1;
+
+ builder.populateFunctionPassManager(*function_passes);
+ builder.populateModulePassManager(*module_passes);
+
+ module_passes->add(llvm::createVerifierPass());
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h
new file mode 100644
index 0000000000..17dadebe97
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
+
+#include "external/llvm/include/llvm/IR/LegacyPassManager.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/Object/ObjectFile.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace cpu {
+
+// Functor class for compiling an LLVM module down to an object file, for use
+// by the ORC JIT compile layer.
+class CompilerFunctor {
+ public:
+ // Describes the set of vector intrinsics available to the generated code.
+ struct VectorIntrinsics {
+ bool sse_intrinsics;
+ bool avx_intrinsics;
+ };
+
+ // Returns a VectorIntrinsics where all intrinsics are available.
+ static VectorIntrinsics AllIntrinsics();
+
+ explicit CompilerFunctor(llvm::TargetMachine* target_machine,
+ const Disassembler* disassembler, int opt_level,
+ const VectorIntrinsics& available_intrinsics)
+ : target_machine_(target_machine),
+ disassembler_(CHECK_NOTNULL(disassembler)),
+ opt_level_(opt_level),
+ available_intrinsics_(available_intrinsics) {}
+
+ // Compile a Module to an ObjectFile.
+ llvm::object::OwningBinary<llvm::object::ObjectFile> operator()(
+ llvm::Module& module) const; // NOLINT
+
+ private:
+ // Populates the given pass managers based on the optimization level.
+ void AddOptimizationPasses(
+ llvm::legacy::PassManagerBase* module_passes,
+ llvm::legacy::FunctionPassManager* function_passes) const;
+
+ llvm::TargetMachine* target_machine_;
+ const Disassembler* disassembler_;
+ const unsigned opt_level_;
+ const VectorIntrinsics available_intrinsics_;
+};
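+
+// Example usage (a minimal sketch; the TargetMachine, Disassembler and LLVM
+// module are assumed to have been created elsewhere, e.g. by SimpleOrcJIT):
+//
+//   CompilerFunctor compile_functor(target_machine, &disassembler,
+//                                   /*opt_level=*/2,
+//                                   CompilerFunctor::AllIntrinsics());
+//   llvm::object::OwningBinary<llvm::object::ObjectFile> object =
+//       compile_functor(*llvm_module);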
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_COMPILER_FUNCTOR_H_
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
new file mode 100644
index 0000000000..2c3fc0abbc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace cpu {
+
+StatusOr<bool> ConvCanonicalization::Run(HloModule* module) {
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ if (!flags->xla_cpu_use_eigen) {
+ return false;
+ }
+
+ bool changed = false;
+ for (HloInstruction* hlo :
+ module->entry_computation()->MakeInstructionPostOrder()) {
+ if (hlo->opcode() == HloOpcode::kConvolution &&
+ !PotentiallyImplementedAsEigenConvolution(*hlo)) {
+ const ConvolutionDimensionNumbers& dnums =
+ hlo->convolution_dimension_numbers();
+ auto batch_dim = dnums.batch_dimension();
+ auto feature_dim = dnums.feature_dimension();
+ auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension();
+ auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension();
+
+ int num_spatial_dims = dnums.spatial_dimensions_size();
+ int num_dims = num_spatial_dims + 2;
+
+ // A canonical convolution's dimension numbers need to satisfy the
+ // following conditions (see cs/PotentiallyImplementedAsEigenConvolution).
+ //
+ // - the input is in NHWC or NWHC order.
+ // - the kernel is in HWIO or WHIO order.
+ // - the spatial dimensions are in the same relative order in the input,
+ // kernel and output.
+ //
+ // For simplicity, as a first step, we reshape the input and filter to
+      // NHWC and HWIO order, respectively. This may lose precision but does
+      // not break soundness.
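+      //
+      // As a concrete example (mirroring conv_canonicalization_test.cc), an
+      // input in CNHW order with batch_dimension=1, spatial_dimensions={2,3}
+      // and feature_dimension=0 is transposed with permutation {1, 2, 3, 0}
+      // into NHWC order, and a kernel in OIHW order is transposed with
+      // permutation {2, 3, 1, 0} into HWIO order.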
+ HloInstruction* input = hlo->mutable_operand(0);
+
+ std::vector<int64> new_input_dim_order(num_dims);
+ std::vector<int64> new_input_dims(num_dims);
+ new_input_dim_order[0] = batch_dim;
+ new_input_dims[0] = input->shape().dimensions(batch_dim);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ new_input_dim_order[i + 1] = dnums.spatial_dimensions(i);
+ new_input_dims[i + 1] =
+ input->shape().dimensions(dnums.spatial_dimensions(i));
+ }
+ new_input_dim_order[num_dims - 1] = feature_dim;
+ new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim);
+
+ Shape new_input_shape =
+ ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims);
+ HloInstruction* new_input = module->entry_computation()->AddInstruction(
+ HloInstruction::CreateTranspose(new_input_shape, input,
+ new_input_dim_order));
+
+ HloInstruction* kernel = hlo->mutable_operand(1);
+
+ std::vector<int64> new_kernel_dim_order(num_dims);
+ std::vector<int64> new_kernel_dims(num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ new_kernel_dim_order[i] = dnums.kernel_spatial_dimensions(i);
+ new_kernel_dims[i] =
+ kernel->shape().dimensions(dnums.kernel_spatial_dimensions(i));
+ }
+ new_kernel_dim_order[num_dims - 2] = kernel_input_feature_dim;
+ new_kernel_dims[num_dims - 2] =
+ kernel->shape().dimensions(kernel_input_feature_dim);
+ new_kernel_dim_order[num_dims - 1] = kernel_output_feature_dim;
+ new_kernel_dims[num_dims - 1] =
+ kernel->shape().dimensions(kernel_output_feature_dim);
+
+ Shape new_kernel_shape =
+ ShapeUtil::MakeShape(kernel->shape().element_type(), new_kernel_dims);
+ HloInstruction* new_kernel = module->entry_computation()->AddInstruction(
+ HloInstruction::CreateTranspose(new_kernel_shape, kernel,
+ new_kernel_dim_order));
+
+ std::vector<int64> new_conv_dims(num_dims);
+ new_conv_dims[0] = hlo->shape().dimensions(batch_dim);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ new_conv_dims[i + 1] =
+ hlo->shape().dimensions(dnums.spatial_dimensions(i));
+ }
+ new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim);
+ Shape new_conv_shape =
+ ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims);
+
+ ConvolutionDimensionNumbers new_dnums;
+ new_dnums.set_batch_dimension(0);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ new_dnums.add_spatial_dimensions(i + 1);
+ new_dnums.add_kernel_spatial_dimensions(i);
+ }
+ new_dnums.set_feature_dimension(num_dims - 1);
+ new_dnums.set_kernel_input_feature_dimension(num_dims - 2);
+ new_dnums.set_kernel_output_feature_dimension(num_dims - 1);
+
+ // The window of the old convolution is reused, because reshapes only
+ // change the dimension mapping but not the dimension sizes. For
+ // example, input height and width are the same as before the reshapes.
+ HloInstruction* new_conv = module->entry_computation()->AddInstruction(
+ HloInstruction::CreateConvolve(new_conv_shape, new_input, new_kernel,
+ hlo->window(), new_dnums));
+
+      // kConvolution inherits the dimension mapping of its input, so we need
+      // to reshape the output back to the shape of the original convolution.
+      // This is done by applying the inverse permutation of the collapsing
+      // order of the input reshape.
+ module->entry_computation()->ReplaceWithNewInstruction(
+ hlo,
+ HloInstruction::CreateTranspose(
+ hlo->shape(), new_conv, InversePermutation(new_input_dim_order)));
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
new file mode 100644
index 0000000000..57e17eb010
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+namespace cpu {
+
+// An HLO pass that canonicalizes the dimension numbers of all top-level
+// convolutions in the given module.
+//
+// In order to hit the fast path of using Eigen's convolution implementation, a
+// convolution's dimension numbers need to satisfy certain constraints
+// (so-called canonical convolutions). This pass expands non-canonical
+// convolutions into reshapes and canonical convolutions, so that they can
+// run faster.
+class ConvCanonicalization : public HloPass {
+ public:
+ ConvCanonicalization() : HloPass("convolution-canonicalization") {}
+ ~ConvCanonicalization() override {}
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CONV_CANONICALIZATION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
new file mode 100644
index 0000000000..d18141af83
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/util.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace cpu {
+
+class ConvCanonicalizationTest : public HloTestBase {
+ public:
+ ConvCanonicalizationTest() {
+ for (int i = 0; i < 2; ++i) {
+ auto dim = conv_window_.add_dimensions();
+ dim->set_size(kWindowSize);
+ dim->set_stride(1);
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ dim->set_window_dilation(1);
+ dim->set_base_dilation(1);
+ }
+ }
+
+ protected:
+ Window conv_window_;
+
+ static constexpr int kBatchSize = 50;
+ static constexpr int kInputSize = 28;
+ static constexpr int kWindowSize = 5;
+ static constexpr int kInputFeatureCount = 32;
+ static constexpr int kOutputFeatureCount = 64;
+};
+
+TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) {
+ auto builder = HloComputation::Builder(TestName());
+ // The input dimensions are in CNHW order.
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
+ kInputFeatureCount, kBatchSize, kInputSize, kInputSize))));
+ // The kernel dimensions are in OIHW order.
+ auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
+ kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize))));
+
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+ dnums.set_feature_dimension(0);
+ dnums.add_kernel_spatial_dimensions(2);
+ dnums.add_kernel_spatial_dimensions(3);
+ dnums.set_kernel_input_feature_dimension(1);
+ dnums.set_kernel_output_feature_dimension(0);
+ auto output_size = kInputSize - kWindowSize + 1;
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(
+ F32, {kOutputFeatureCount, kBatchSize, output_size, output_size}),
+ input, kernel, conv_window_, dnums));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+
+ ConvCanonicalization conv_canonicalization;
+ EXPECT_TRUE(conv_canonicalization.Run(module.get()).ValueOrDie());
+
+ const HloInstruction* output_reshape = entry_computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kTranspose, output_reshape->opcode());
+ const HloInstruction* canonical_conv = output_reshape->operand(0);
+ EXPECT_EQ(HloOpcode::kConvolution, canonical_conv->opcode());
+ const HloInstruction* input_reshape = canonical_conv->operand(0);
+ EXPECT_EQ(HloOpcode::kTranspose, input_reshape->opcode());
+ const HloInstruction* kernel_reshape = canonical_conv->operand(1);
+ EXPECT_EQ(HloOpcode::kTranspose, kernel_reshape->opcode());
+
+ // The input is in CNHW order. input_reshape should produce
+ // NHWC for the convolution to hit the Eigen fast path.
+ EXPECT_TRUE(ContainersEqual(input_reshape->dimensions(),
+ std::vector<int64>({1, 2, 3, 0})));
+ // The kernel is in OIHW order. kernel_reshape should produce
+ // HWIO for the convolution to hit the Eigen fast path.
+ EXPECT_TRUE(ContainersEqual(kernel_reshape->dimensions(),
+ std::vector<int64>({2, 3, 1, 0})));
+ // The output of the canonical convolution is in NHWC order (the same as
+ // input_reshape's order). output_reshape should restore that order to the
+ // order of the computation root (CNHW).
+ EXPECT_TRUE(ContainersEqual(output_reshape->dimensions(),
+ std::vector<int64>({3, 0, 1, 2})));
+}
+
+TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) {
+ auto builder = HloComputation::Builder(TestName());
+ // The input dimensions are in NHWC order.
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
+ kBatchSize, kInputSize, kInputSize, kInputFeatureCount))));
+ // The kernel dimensions are in HWIO order.
+ auto kernel = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR4FromArray4D(Array4D<float>(
+ kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount))));
+
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+ auto output_size = kInputSize - kWindowSize + 1;
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(
+ F32, {kBatchSize, output_size, output_size, kOutputFeatureCount}),
+ input, kernel, conv_window_, dnums));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+
+ ConvCanonicalization conv_canonicalization;
+ EXPECT_FALSE(conv_canonicalization.Run(module.get()).ValueOrDie());
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
new file mode 100644
index 0000000000..d566cfd8c8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -0,0 +1,631 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+
+#include <stddef.h>
+#include <string.h>
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+// IWYU pragma: no_include "llvm/Config/Disassemblers.def.inc"
+// IWYU pragma: no_include "llvm/Config/Targets.def.inc"
+#include "external/llvm/include/llvm/ADT/StringRef.h"
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/Object/ObjectFile.h"
+#include "external/llvm/include/llvm/Support/CommandLine.h"
+#include "external/llvm/include/llvm/Support/TargetRegistry.h"
+#include "external/llvm/include/llvm/Support/TargetSelect.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "external/llvm/include/llvm/Target/TargetOptions.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/port/initialize.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
+#include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
+#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h"
+#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/transpose_folding.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace cpu {
+
+CpuAotCompilationOptions::CpuAotCompilationOptions(
+ string triple, string cpu_name, string features, string entry_point_name,
+ RelocationModel relocation_model)
+ : triple_(std::move(triple)),
+ cpu_name_(std::move(cpu_name)),
+ features_(std::move(features)),
+ entry_point_name_(std::move(entry_point_name)),
+ relocation_model_(relocation_model) {}
+
+CpuAotCompilationOptions::~CpuAotCompilationOptions() = default;
+
+se::Platform::Id CpuAotCompilationOptions::PlatformId() const {
+ return se::host::kHostPlatformId;
+}
+
+CpuAotCompilationResult::CpuAotCompilationResult(
+ ObjectFileData object_file_data, BufferSizes buffer_sizes,
+ int64 result_buffer_index)
+ : object_file_data_(std::move(object_file_data)),
+ buffer_sizes_(std::move(buffer_sizes)),
+ result_buffer_index_(result_buffer_index) {}
+
+CpuAotCompilationResult::~CpuAotCompilationResult() = default;
+
+CpuCompiler::CpuCompiler() {
+ // Initialize LLVM the first time the CpuCompiler is initialized.
+ static bool llvm_initialized = []() {
+ InitializeLLVMTarget();
+ return true;
+ }();
+ (void)llvm_initialized;
+}
+
+/* static */ void CpuCompiler::InitializeLLVMTarget() {
+ // Initialize LLVM's MC layer for the native target.
+ llvm::InitializeNativeTarget();
+ llvm::InitializeNativeTargetAsmPrinter();
+ LLVMInitializeX86Target();
+ LLVMInitializeX86TargetInfo();
+ LLVMInitializeX86TargetMC();
+ LLVMInitializeX86AsmPrinter();
+ LLVMInitializeX86Disassembler();
+ LLVMInitializeARMTarget();
+ LLVMInitializeARMTargetInfo();
+ LLVMInitializeARMTargetMC();
+ LLVMInitializeARMAsmPrinter();
+ LLVMInitializeARMDisassembler();
+ LLVMInitializeAArch64Target();
+ LLVMInitializeAArch64TargetInfo();
+ LLVMInitializeAArch64TargetMC();
+ LLVMInitializeAArch64AsmPrinter();
+ LLVMInitializeAArch64Disassembler();
+ LLVMInitializePowerPCTarget();
+ LLVMInitializePowerPCTargetInfo();
+ LLVMInitializePowerPCTargetMC();
+ LLVMInitializePowerPCAsmPrinter();
+ LLVMInitializePowerPCDisassembler();
+
+ // LLVM command-line flags are global, so set them during initialization.
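+  // For example (hypothetical value), --xla_cpu_llvm_cl_opts=-print-after-all
+  // would be split on ',' below and forwarded to LLVM as if it had been
+  // passed on the command line.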
+ legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
+ if (!flags->xla_cpu_llvm_cl_opts.empty()) {
+ std::vector<string> opts =
+ tensorflow::str_util::Split(flags->xla_cpu_llvm_cl_opts, ',');
+ std::vector<const char*> fake_argv;
+ fake_argv.push_back("");
+ for (const string& opt : opts) {
+ fake_argv.push_back(opt.c_str());
+ }
+ llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
+ }
+}
+
+namespace {
+// This visitor determines which HLO instructions should have profiling
+// information recorded.
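+// For example, for a computation whose root is add = Add(parameter, constant),
+// only the add instruction receives a profile index; constants and parameters
+// are skipped by the handlers below.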
+class CollectProfileCandidates : public DfsHloVisitorWithDefault {
+ public:
+ static StatusOr<std::unordered_map<const HloInstruction*, size_t>>
+ GetCandidatesForComputation(HloComputation* computation) {
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
+ CollectProfileCandidates profile_candidates_for_computation(
+ &hlo_to_profile_idx);
+ TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
+ &profile_candidates_for_computation));
+ return hlo_to_profile_idx;
+ }
+
+ private:
+ explicit CollectProfileCandidates(
+ std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
+ : hlo_to_profile_idx_(hlo_to_profile_idx) {}
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()});
+ return Status::OK();
+ }
+  // Skip constants; there is nothing to profile.
+ Status HandleConstant(HloInstruction* /*constant*/,
+ const Literal& /*literal*/) override {
+ return Status::OK();
+ }
+  // Skip parameters; they are a simple load.
+ Status HandleParameter(HloInstruction* /*parameter*/) override {
+ return Status::OK();
+ }
+ // It is important to recurse for "while" or else we risk overly coarse
+ // profiling information.
+ Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/,
+ HloComputation* condition, HloComputation* body) override {
+ TF_RETURN_IF_ERROR(DefaultAction(xla_while));
+
+ CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_);
+ TF_RETURN_IF_ERROR(
+ condition->root_instruction()->Accept(&candidates_for_condition));
+
+ CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_);
+ TF_RETURN_IF_ERROR(body->root_instruction()->Accept(&candidates_for_body));
+
+ return Status::OK();
+ }
+
+ std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
+};
+} // namespace
+
+Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
+ HloModuleConfig* module_config,
+ HloDumper dump_hlo) {
+ // Optimization pipeline.
+ HloPassPipeline pipeline("CPU", dump_hlo);
+ pipeline.AddPass<Inliner>();
+ pipeline.AddPass<ConvCanonicalization>();
+ {
+ auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification",
+ dump_hlo);
+ pass.AddPass<AlgebraicSimplifier>(
+ /*is_layout_sensitive=*/false,
+ [](const Shape&, const Shape&) { return false; });
+ pass.AddPass<ReshapeMover>();
+ }
+ pipeline.AddPass<TransposeFolding>(PotentiallyImplementedAsEigenDot);
+ pipeline.AddPass<HloSubcomputationUnification>();
+ pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
+ pipeline.AddPass<CpuInstructionFusion>();
+ pipeline.AddPass<CpuLayoutAssignment>(
+ module_config->mutable_entry_computation_layout());
+ // The LayoutAssignment pass may leave behind kCopy instructions which are
+ // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ [](const Shape&, const Shape&) { return true; });
+ pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+  // Copy insertion should be performed immediately before IR emission to
+  // avoid inserting unnecessary copies (a later pass could add an instruction
+  // which materializes the value) or missing a necessary copy (a later pass
+  // could remove an instruction which materializes a value).
+ pipeline.AddPass<CopyInsertion>();
+ pipeline.AddPass<HloDCE>();
+ legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
+ if (flags->xla_cpu_parallel) {
+ pipeline.AddPass<ParallelizationPreparation>();
+ }
+ return pipeline.Run(hlo_module).status();
+}
+
+namespace {
+
+llvm::TargetOptions CompilerTargetOptions() {
+ llvm::TargetOptions target_options;
+ llvm_ir::SetTargetOptions(&target_options);
+ return target_options;
+}
+
+llvm::CodeGenOpt::Level CodeGenOptLevel() {
+ legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
+ switch (flags->xla_cpu_llvm_opt_level) {
+ case 1:
+ return llvm::CodeGenOpt::Less;
+ case 2:
+ return llvm::CodeGenOpt::Default;
+ case 3:
+ return llvm::CodeGenOpt::Aggressive;
+ default:
+ return llvm::CodeGenOpt::None;
+ }
+}
+
+// Constructs and returns a sequence for the HLO instructions in each
+// computation in the given module. The sequence can be used to determine the
+// order of HLO instruction emission and for buffer liveness analysis.
+SequentialHloOrdering::HloModuleSequence CreateModuleSequence(
+ const HloModule* module) {
+ SequentialHloOrdering::HloModuleSequence sequence;
+ for (auto& computation : module->computations()) {
+ // Do a DFS traversal from the root to construct a sequence for each
+ // computation.
+ // TODO(b/32006145): Construct a sequence to minimize memory pressure.
+ std::vector<const HloInstruction*> order;
+ TF_CHECK_OK(computation->root_instruction()->Accept(
+ [&order](HloInstruction* instruction) {
+ order.push_back(instruction);
+ return Status::OK();
+ }));
+ sequence.emplace(computation.get(), std::move(order));
+ }
+ return sequence;
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ se::StreamExecutor* stream_exec) {
+ TF_RET_CHECK(stream_exec != nullptr);
+
+ // Compile must be thread-safe so create a new LLVM context for the module.
+ auto llvm_context = MakeUnique<llvm::LLVMContext>();
+ auto llvm_module =
+ MakeUnique<llvm::Module>("__compute_module", *llvm_context);
+ auto jit =
+ MakeUnique<SimpleOrcJIT>(CompilerTargetOptions(), CodeGenOptLevel());
+ llvm_module->setDataLayout(jit->data_layout());
+ llvm_module->setTargetTriple(jit->target_triple().getTriple());
+ const llvm::DataLayout& data_layout = llvm_module->getDataLayout();
+ int64 pointer_size = data_layout.getPointerSize();
+
+ TF_RETURN_IF_ERROR(
+ RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo));
+
+ HloComputation* computation = hlo_module->entry_computation();
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
+ if (module_config->hlo_profiling_enabled()) {
+ TF_ASSIGN_OR_RETURN(
+ hlo_to_profile_idx,
+ CollectProfileCandidates::GetCandidatesForComputation(computation));
+ }
+
+ std::unique_ptr<Executable> cpu_executable;
+ legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
+ if (flags->xla_cpu_parallel) {
+ // Run buffer analysis on the HLO graph. This analysis figures out which
+ // temporary buffers are required to run the computation.
+ // DependencyHloOrdering is used for the parallel emitter because the order
+ // of HLO instruction execution is not known ahead of time.
+ // DependencyHloOrdering is the most conservative partial order and only
+ // uses data dependencies for determining order.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> assignment,
+ BufferAssigner::Run(hlo_module.get(),
+ MakeUnique<DependencyHloOrdering>(hlo_module.get()),
+ pointer_size));
+
+ // If we are using the parallel CPU backend, we need to create a map from
+ // HloInstruction to the corresponding generated function name.
+ std::map<HloComputation*, HloInstruction*> parallel_computations;
+ std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
+ aligned_constants;
+ for (auto instruction : computation->MakeInstructionPostOrder()) {
+ // Parameters and constants don't get their own computation.
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ if (instruction->opcode() == HloOpcode::kConstant) {
+ // Copy the constant out of the ProtocolBuffer so that we can give it a
+ // higher alignment.
+ const void* data = LiteralUtil::InternalData(instruction->literal());
+ int64 size = llvm_ir::ByteSizeOf(instruction->shape(), data_layout);
+ auto iter = aligned_constants.emplace(
+ instruction, MakeUnique<unsigned char[]>(size));
+ CHECK_EQ(iter.second, true);
+ unsigned char* aligned_data = iter.first->second.get();
+ memcpy(aligned_data, data, size);
+ continue;
+ }
+ // The parallel preparation should have ensured that the top-level
+ // computation consists solely of Call instructions.
+ TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall);
+ HloComputation* to_apply = instruction->to_apply();
+ parallel_computations.emplace(to_apply, instruction);
+ }
+
+ IrEmitter ir_emitter(*hlo_module, *module_config, *assignment,
+ llvm_module.get(), &hlo_to_profile_idx);
+ std::unique_ptr<std::map<HloInstruction*, string>> function_names(
+ new std::map<HloInstruction*, string>());
+ for (auto embedded_computation :
+ computation->MakeEmbeddedComputationsList()) {
+ auto parallel_computation_iter =
+ parallel_computations.find(embedded_computation);
+ // All parallel computations are considered to be an entry computation for
+ // IR generation purposes.
+ bool computation_is_parallel =
+ parallel_computation_iter != parallel_computations.end();
+ TF_ASSIGN_OR_RETURN(
+ llvm::Function * ir_function,
+ ir_emitter.EmitComputation(
+ embedded_computation, embedded_computation->name(),
+ /*is_entry_computation=*/computation_is_parallel));
+ // If this computation is parallel, remember it in the function name map.
+ // This way we know what function to execute when we try to run code for
+ // the Call instruction.
+ if (computation_is_parallel) {
+ HloInstruction* call_instruction = parallel_computation_iter->second;
+ InsertOrDie(function_names.get(), call_instruction,
+ llvm_ir::AsString(ir_function->getName()));
+ }
+ }
+
+ string ir_module_string;
+ if (flags->xla_cpu_embed_ir) {
+ ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
+ }
+
+ // JIT compile the LLVM IR module to in-memory machine code.
+ jit->AddModule(std::move(llvm_module));
+ cpu_executable.reset(new ParallelCpuExecutable(
+ std::move(jit), std::move(assignment), std::move(hlo_module),
+ std::move(module_config), std::move(function_names),
+ std::move(hlo_to_profile_idx), std::move(aligned_constants)));
+
+ if (flags->xla_cpu_embed_ir) {
+ static_cast<CpuExecutable&>(*cpu_executable)
+ .set_ir_module_string(ir_module_string);
+ }
+ } else {
+ // Select an order for emitting the HLO instructions for each
+ // computation. Using this sequence enables tighter buffer liveness analysis
+ // and reduced memory usage (as compared to using DependencyHloOrdering).
+ SequentialHloOrdering::HloModuleSequence module_sequence =
+ CreateModuleSequence(hlo_module.get());
+
+ // Run buffer analysis on the HLO graph. This analysis figures out which
+ // temporary buffers are required to run the computation.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> assignment,
+ BufferAssigner::Run(hlo_module.get(),
+ MakeUnique<SequentialHloOrdering>(hlo_module.get(),
+ module_sequence),
+ pointer_size));
+
+ // Each computation is a single function. Emit all embedded computations
+ // before the entry computation. The order of computations returned from
+ // MakeEmbeddedComputationsList guarantees that a called computation occurs
+ // before a caller computation.
+ IrEmitter ir_emitter(*hlo_module, *module_config, *assignment,
+ llvm_module.get(), &hlo_to_profile_idx);
+ for (auto embedded_computation :
+ computation->MakeEmbeddedComputationsList()) {
+ TF_RETURN_IF_ERROR(
+ ir_emitter
+ .EmitComputation(embedded_computation,
+ embedded_computation->name(),
+ /*is_entry_computation=*/false,
+ &module_sequence.at(embedded_computation))
+ .status());
+ }
+ string function_name_prefix =
+ computation->name().empty() ? "__compute" : computation->name();
+ TF_ASSIGN_OR_RETURN(
+ llvm::Function * entry_function,
+ ir_emitter.EmitComputation(computation, function_name_prefix,
+ /*is_entry_computation=*/true,
+ &module_sequence.at(computation)));
+
+ string function_name = llvm_ir::AsString(entry_function->getName());
+ string ir_module_string;
+ if (flags->xla_cpu_embed_ir) {
+ ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
+ }
+
+ // JIT compile the LLVM IR module to in-memory machine code.
+ jit->AddModule(std::move(llvm_module));
+ cpu_executable.reset(
+ new CpuExecutable(std::move(jit), std::move(assignment),
+ std::move(hlo_module), std::move(module_config),
+ function_name, std::move(hlo_to_profile_idx)));
+
+ if (flags->xla_cpu_embed_ir) {
+ static_cast<CpuExecutable&>(*cpu_executable)
+ .set_ir_module_string(ir_module_string);
+ }
+ }
+
+ return std::move(cpu_executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
+ return Unimplemented(
+ "Compilation of multiple HLO modules is not yet supported on CPU.");
+}
+
+StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ const AotCompilationOptions& aot_options) {
+ if (aot_options.PlatformId() != se::host::kHostPlatformId) {
+ return InvalidArgument("Incompatible AOT compilation platform");
+ }
+ const CpuAotCompilationOptions& options =
+ static_cast<const CpuAotCompilationOptions&>(aot_options);
+ llvm::StringRef target_triple = llvm_ir::AsStringRef(options.triple());
+ llvm::Triple triple(llvm::Triple::normalize(target_triple));
+ std::string error;
+ const llvm::Target* target =
+ llvm::TargetRegistry::lookupTarget(triple.getTriple(), error);
+ if (target == nullptr) {
+ return InternalError("TargetRegistry::lookupTarget failed: %s",
+ error.c_str());
+ }
+
+ llvm::Reloc::Model reloc_model;
+ llvm::PICLevel::Level pic_level;
+ llvm::PIELevel::Level pie_level;
+ switch (options.relocation_model()) {
+ case CpuAotCompilationOptions::RelocationModel::Static:
+ reloc_model = llvm::Reloc::Static;
+ pic_level = llvm::PICLevel::NotPIC;
+ pie_level = llvm::PIELevel::Default;
+ break;
+ case CpuAotCompilationOptions::RelocationModel::SmallPic:
+ reloc_model = llvm::Reloc::PIC_;
+ pic_level = llvm::PICLevel::SmallPIC;
+ pie_level = llvm::PIELevel::Default;
+ break;
+ case CpuAotCompilationOptions::RelocationModel::BigPic:
+ reloc_model = llvm::Reloc::PIC_;
+ pic_level = llvm::PICLevel::BigPIC;
+ pie_level = llvm::PIELevel::Default;
+ break;
+ case CpuAotCompilationOptions::RelocationModel::SmallPie:
+ reloc_model = llvm::Reloc::PIC_;
+ pic_level = llvm::PICLevel::SmallPIC;
+ pie_level = llvm::PIELevel::Small;
+ break;
+ case CpuAotCompilationOptions::RelocationModel::BigPie:
+ reloc_model = llvm::Reloc::PIC_;
+ pic_level = llvm::PICLevel::BigPIC;
+ pie_level = llvm::PIELevel::Large;
+ break;
+ }
+ llvm::StringRef cpu_name = llvm_ir::AsStringRef(options.cpu_name());
+ llvm::StringRef features = llvm_ir::AsStringRef(options.features());
+ llvm::CodeGenOpt::Level opt_level = CodeGenOptLevel();
+ std::unique_ptr<llvm::TargetMachine> target_machine =
+ WrapUnique(target->createTargetMachine(
+ triple.getTriple(), cpu_name, features, CompilerTargetOptions(),
+ reloc_model, llvm::CodeModel::Default, opt_level));
+
+ // Compile must be thread-safe so create a new LLVM context for the module.
+ llvm::LLVMContext llvm_context;
+ llvm::Module llvm_module("__compute_module", llvm_context);
+ llvm_module.setDataLayout(target_machine->createDataLayout());
+ llvm_module.setTargetTriple(triple.getTriple());
+ if (pic_level != llvm::PICLevel::NotPIC) {
+ llvm_module.setPICLevel(pic_level);
+ }
+ if (pie_level != llvm::PIELevel::Default) {
+ llvm_module.setPIELevel(pie_level);
+ }
+ const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
+ int64 pointer_size = data_layout.getPointerSize();
+
+ TF_RETURN_IF_ERROR(
+ RunHloPasses(hlo_module.get(), module_config.get(), dump_hlo));
+
+ SequentialHloOrdering::HloModuleSequence module_sequence =
+ CreateModuleSequence(hlo_module.get());
+ // Run buffer analysis on the HLO graph. This analysis figures out which
+ // temporary buffers are required to run the computation.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> assignment,
+ BufferAssigner::Run(
+ hlo_module.get(),
+ MakeUnique<SequentialHloOrdering>(hlo_module.get(), module_sequence),
+ pointer_size));
+
+ IrEmitter ir_emitter(*hlo_module, *module_config, *assignment, &llvm_module,
+ /*hlo_to_profile_idx=*/nullptr);
+ HloComputation* computation = hlo_module->entry_computation();
+ for (auto embedded_computation :
+ computation->MakeEmbeddedComputationsList()) {
+ TF_RETURN_IF_ERROR(
+ ir_emitter
+ .EmitComputation(embedded_computation, embedded_computation->name(),
+ /*is_entry_computation=*/false,
+ &module_sequence.at(embedded_computation))
+ .status());
+ }
+ const string& entry_point_name = options.entry_point_name();
+ TF_ASSIGN_OR_RETURN(
+ llvm::Function * entry_function,
+ ir_emitter.EmitComputation(computation, entry_point_name,
+ /*is_entry_computation=*/true));
+
+ entry_function->setName(llvm_ir::AsStringRef(entry_point_name));
+
+ Disassembler disassembler(*target_machine);
+ CompilerFunctor compiler_functor(target_machine.get(), &disassembler,
+ opt_level, CompilerFunctor::AllIntrinsics());
+ llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
+ compiler_functor(llvm_module);
+ llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
+ ObjectFileData object_file_data(object_file_data_ref.begin(),
+ object_file_data_ref.end());
+
+ BufferSizes buffer_sizes;
+ for (const BufferAllocation& allocation : assignment->Allocations()) {
+ // Callers don't need to allocate temporary buffers for parameters.
+ if (allocation.is_entry_computation_parameter()) {
+ buffer_sizes.push_back(-1);
+ continue;
+ }
+ // Callers don't need to allocate anything for thread-local temporary
+ // buffers. They are lowered to allocas.
+ if (allocation.is_thread_local()) {
+ buffer_sizes.push_back(-1);
+ continue;
+ }
+ buffer_sizes.push_back(allocation.size());
+ }
+
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment->GetUniqueTopLevelOutputAllocation());
+
+ return std::unique_ptr<AotCompilationResult>(
+ MakeUnique<CpuAotCompilationResult>(std::move(object_file_data),
+ std::move(buffer_sizes),
+ result_allocation->index()));
+}
+
+se::Platform::Id CpuCompiler::PlatformId() const {
+ return se::host::kHostPlatformId;
+}
+
+} // namespace cpu
+} // namespace xla
+
+REGISTER_MODULE_INITIALIZER(cpu_compiler, {
+ xla::Compiler::RegisterCompilerFactory(se::host::kHostPlatformId, []() {
+ return xla::MakeUnique<xla::cpu::CpuCompiler>();
+ });
+});
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
new file mode 100644
index 0000000000..349724d840
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace cpu {
+
+// This class wraps the configurability options that LLVM exposes, including
+// the target triple, the target CPU and the target features. It also includes
+// the desired linkage name for the computation entry point.
+// Note that the optimization level can be controlled by the
+// --xla_cpu_llvm_opt_level flag.
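+//
+// A minimal usage sketch; the triple, CPU, feature and entry-point strings
+// below are illustrative placeholders rather than values required by XLA:
+//
+//   CpuAotCompilationOptions options(
+//       /*triple=*/"x86_64-unknown-linux-gnu", /*cpu_name=*/"",
+//       /*features=*/"", /*entry_point_name=*/"my_entry_point",
+//       CpuAotCompilationOptions::RelocationModel::Static);
+//
+// The resulting `options` can then be passed to
+// CpuCompiler::CompileAheadOfTime.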
+class CpuAotCompilationOptions : public AotCompilationOptions {
+ public:
+ // Relocation models available for compilation.
+ enum class RelocationModel {
+ // Corresponds to the -fno-pic compiler option.
+ Static,
+ // Corresponds to the -fpic compiler option.
+ SmallPic,
+ // Corresponds to the -fPIC compiler option.
+ BigPic,
+ // Corresponds to the -fpie compiler option.
+ SmallPie,
+ // Corresponds to the -fPIE compiler option.
+ BigPie
+ };
+
+ CpuAotCompilationOptions(string triple, string cpu_name, string features,
+ string entry_point_name,
+ RelocationModel relocation_model);
+ ~CpuAotCompilationOptions() override;
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ // The triple used for compilation, similar to clang's -target flag.
+ const string& triple() const { return triple_; }
+ // The CPU name used for compilation, similar to clang's -mcpu flag.
+ const string& cpu_name() const { return cpu_name_; }
+ // The target features used for compilation ("+avx2", "+neon", etc).
+ const string& features() const { return features_; }
+ // The name to be used for the compiled code's entry point.
+ const string& entry_point_name() const { return entry_point_name_; }
+ // The relocation model used for compilation.
+ RelocationModel relocation_model() const { return relocation_model_; }
+
+ private:
+ const string triple_;
+ const string cpu_name_;
+ const string features_;
+ const string entry_point_name_;
+ const RelocationModel relocation_model_;
+};
+
+class CpuAotCompilationResult : public AotCompilationResult {
+ public:
+ CpuAotCompilationResult(ObjectFileData object_file_data,
+ BufferSizes buffer_sizes, int64 result_buffer_index);
+ ~CpuAotCompilationResult();
+
+ const ObjectFileData& object_file_data() const { return object_file_data_; }
+ const BufferSizes& buffer_sizes() const { return buffer_sizes_; }
+ int64 result_buffer_index() const { return result_buffer_index_; }
+
+ private:
+ // Contains the compiled computation: an object file.
+ const ObjectFileData object_file_data_;
+
+ // The list of buffer sizes which should be allocated in order to execute the
+ // compiled computation. These cover the temporary buffers used ephemerally
+ // during the computation as well as the output result.
+ const BufferSizes buffer_sizes_;
+
+ // The index into |buffer_sizes| of the buffer designated to hold the result
+ // of the computation. This buffer should be passed into the output parameter
+ // when calling the compiled computation.
+ const int64 result_buffer_index_;
+};
+
+// CPU-targeting implementation of the XLA Compiler interface.
+//
+// The compiler translates XLA HLO code into LLVM IR and uses LLVM's JIT
+// infrastructure to create an executable "blob" that can then be returned
+// wrapped in CpuExecutable and actually invoked.
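+//
+// A rough usage sketch (in practice the compiler is obtained via the compiler
+// factory registered for the host platform rather than constructed directly;
+// the variable names below are illustrative):
+//
+//   CpuCompiler compiler;
+//   TF_ASSIGN_OR_RETURN(
+//       std::unique_ptr<Executable> executable,
+//       compiler.Compile(std::move(hlo_module), std::move(module_config),
+//                        dump_hlo, stream_exec));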
+class CpuCompiler : public Compiler {
+ public:
+ CpuCompiler();
+ ~CpuCompiler() override {}
+
+ StatusOr<std::unique_ptr<Executable>> Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ perftools::gputools::StreamExecutor* stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ HloDumper dump_hlo,
+ std::vector<perftools::gputools::StreamExecutor*> stream_execs) override;
+
+ StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ const AotCompilationOptions& options) override;
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ private:
+ // Initialize the LLVM target.
+ static void InitializeLLVMTarget();
+
+ // Runs the HLO passes which are necessary for both optimizations and
+ // correctness.
+ Status RunHloPasses(HloModule* hlo_module, HloModuleConfig* module_config,
+ HloDumper dump_hlo);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler);
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_COMPILER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
new file mode 100644
index 0000000000..727257d4f1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -0,0 +1,477 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
+
+#include <stdint.h>
+#include <algorithm>
+#include <set>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace cpu {
+
+CpuExecutable::CpuExecutable(
+ std::unique_ptr<SimpleOrcJIT> jit,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ const string& entry_function_name,
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx)
+ : Executable(std::move(hlo_module), std::move(module_config)),
+ jit_(std::move(jit)),
+ assignment_(std::move(assignment)),
+ hlo_to_profile_idx_(std::move(hlo_to_profile_idx)) {
+ // Resolve symbols in the constructor rather than at execution time to avoid
+ // races because FindSymbol is not thread safe.
+ llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name);
+ // We expect to find the symbol provided with entry_function_name; otherwise
+ // this is an internal error.
+ CHECK(sym) << "Symbol " << entry_function_name << " not found.";
+ // getAddress can do work under the hood in the jit, so it needs to be
+ // guarded by the mutex.
+ compute_function_ = reinterpret_cast<ComputeFunctionType>(sym.getAddress());
+}
+
+// Given a pointer to an output buffer (following the CPU JIT calling
+// conventions), mark addresses that are "live". The initial pointer itself is
+// trivially live. If the shape of the buffer is a tuple, this analysis
+// recursively looks into the tuple's elements and marks them live as well,
+// since tuples keep pointers to buffers.
+//
+// address: an in-memory buffer address that contains some runtime XLA object.
+// shape: the shape of that object.
+// marked_addresses: the set of live addresses to populate.
+static void MarkLiveAddressesInOutput(
+ const void* address, const Shape& shape,
+ std::unordered_set<const void*>* marked_addresses) {
+ marked_addresses->insert(address);
+ const uintptr_t* address_buffer = static_cast<const uintptr_t*>(address);
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const uintptr_t* element_address = address_buffer + i;
+ const void* element = reinterpret_cast<const void*>(*element_address);
+ MarkLiveAddressesInOutput(
+ element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses);
+ }
+ }
+}
+
+Status CpuExecutable::AllocateBuffers(
+ DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* buffers) {
+ CHECK_EQ(buffers->size(), assignment_->Allocations().size());
+ VLOG(3) << "Allocating " << assignment_->Allocations().size()
+ << " allocations for module " << module().name();
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ auto& allocation = assignment_->GetAllocation(i);
+
+ VLOG(3) << allocation.ToString();
+
+ if (allocation.is_entry_computation_parameter()) {
+ VLOG(3) << "allocation #" << i << " is a parameter";
+ continue;
+ }
+
+ if (allocation.is_thread_local()) {
+ VLOG(3) << "buffer #" << i << " is thread-local";
+ continue;
+ }
+
+ int64 buffer_size = allocation.size();
+ if (!(*buffers)[i].is_null()) {
+ VLOG(3) << "buffer #" << i
+ << " is in the preallocated result ShapedBuffer";
+ } else {
+ TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
+ device_ordinal, buffer_size));
+
+ VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
+ << (*buffers)[i].opaque() << "]";
+ }
+
+ // Since the output buffer and all the temporary buffers were written into
+ // by the JITed code, msan has no way of knowing their memory was
+ // initialized. Mark them initialized so that msan doesn't flag loads from
+ // these buffers.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
+ }
+
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelOutputAllocation());
+
+ VLOG(3) << "result index: " << result_allocation->index();
+
+ return Status::OK();
+}
+
+Status CpuExecutable::ExecuteComputeFunction(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+ HloExecutionProfile* hlo_execution_profile) {
+ std::vector<se::DeviceMemoryBase> argument_buffers;
+ for (int i = 0; i < arguments.size(); ++i) {
+ TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape()));
+ argument_buffers.push_back(arguments[i]->buffer(/*index=*/{}));
+ }
+ return ExecuteComputeFunction(run_options, argument_buffers, buffers,
+ hlo_execution_profile);
+}
+
+Status CpuExecutable::ExecuteComputeFunction(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+ HloExecutionProfile* hlo_execution_profile) {
+ // The calling convention for JITed functions is:
+ //
+ // void function(void* result, const void* run_options, void** args_array,
+ // void** temps_array, uint64* profile_counters)
+ //
+ // result: Points at the result.
+ // run_options: the ExecutableRunOptions object.
+ // args_array: An array of pointers, each of which points to a parameter.
+ // The size of this array is determined by the function's arity
+ // (ProgramShape).
+ // temps_array: An array of pointers, each of which points to a temporary
+ // buffer the computation needs. The size of this array is
+ // determined by buffer analysis.
+ // profile_counters: An array of profiling counters, one per profiled HLO
+ // instruction plus one final counter for the entire computation.
+ //
+ std::vector<const void*> args_array;
+ for (se::DeviceMemoryBase arg_mem : arguments) {
+ args_array.push_back(arg_mem.opaque());
+ }
+
+ uint64 start_micros = tensorflow::Env::Default()->NowMicros();
+
+ // Allocate profiling counters for each hlo instruction that we would like to
+ // profile. Allocate an additional profile counter for the entire
+ // computation.
+ std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
+
+ // Call the computation function following the calling convention.
+ std::vector<void*> buffer_pointers;
+ for (auto& buffer : buffers) {
+ buffer_pointers.push_back(const_cast<void*>(buffer.opaque()));
+ }
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelOutputAllocation());
+ void* result_buffer = buffer_pointers[result_allocation->index()];
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << "Executing compute function:";
+ VLOG(3) << tensorflow::strings::Printf(
+ " func(void* result, void* params[%zu], void* temps[%zu], "
+ "uint64 profile_counters[%zu])",
+ args_array.size(), buffer_pointers.size(), profile_counters.size());
+ VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer);
+ auto ptr_printer = [](string* out, const void* p) {
+ tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p));
+ };
+ VLOG(3) << tensorflow::strings::Printf(
+ " params = [%s]",
+ tensorflow::str_util::Join(args_array, ", ", ptr_printer).c_str());
+ VLOG(3) << tensorflow::strings::Printf(
+ " temps = [%s]",
+ tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str());
+ VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p",
+ profile_counters.data());
+ }
+
+ compute_function_(result_buffer, run_options, args_array.data(),
+ buffer_pointers.data(), profile_counters.data());
+
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ const double nanoseconds = (end_micros - start_micros) * 1000.0;
+ execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+
+ // The last profile counter is used for the computation as a whole.
+ execution_profile_.set_compute_cycle_count(profile_counters.back());
+ }
+
+ if (hlo_execution_profile != nullptr) {
+ hlo_execution_profile->set_total_cycles_executed(profile_counters.back());
+
+ for (auto hlo_prof_idx : hlo_to_profile_idx_) {
+ const HloInstruction* hlo = hlo_prof_idx.first;
+ uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
+ hlo_execution_profile->AddProfileResult(hlo, cycles_taken);
+ }
+ }
+ return Status::OK();
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase> CpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+ TF_RETURN_IF_ERROR(AllocateBuffers(
+ memory_allocator, stream->parent()->device_ordinal(), &buffers));
+
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers,
+ hlo_execution_profile));
+
+ // Mark the buffers that are actually live (used in the output) when the
+ // computation finishes executing.
+ std::unordered_set<const void*> marked_addresses;
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelOutputAllocation());
+ se::DeviceMemoryBase top_level_output = buffers[result_allocation->index()];
+ MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(),
+ &marked_addresses);
+
+ VLOG(3) << "Live addresses in output marking found "
+ << marked_addresses.size() << " addresses:\n"
+ << tensorflow::str_util::Join(
+ marked_addresses, ", ", [](string* out, const void* address) {
+ tensorflow::strings::StrAppend(
+ out, tensorflow::strings::Printf("%p", address));
+ });
+
+ // Computation is done - deallocate temp buffers. Keep those marked live
+ // because they are referenced by the output of the computation and are needed
+ // by the service. They will be deallocated by the service.
+ for (auto i = 0; i < buffers.size(); ++i) {
+ auto alloc = buffers[i];
+ if (marked_addresses.count(alloc.opaque()) == 0 &&
+ alloc.opaque() != nullptr) {
+ VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
+ << alloc.opaque() << "]";
+ TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
+ stream->parent()->device_ordinal(), &alloc));
+ }
+ }
+
+ return top_level_output;
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> CpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented("Points-to set of root instruction is ambiguous");
+ }
+ std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<ShapedBuffer> result_buffer,
+ ShapedBuffer::MakeShapedBuffer(
+ module_config().entry_computation_layout().result_shape(),
+ stream->parent()->platform(), stream->parent()->device_ordinal()));
+
+ TF_RETURN_IF_ERROR(AllocateBuffers(
+ memory_allocator, stream->parent()->device_ordinal(), &buffers));
+
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers,
+ hlo_execution_profile));
+
+ // Copy DeviceMemoryBase values which contain the array(s) of the result into
+ // the respective location in ShapedBuffer which is returned to the caller.
+ std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
+ TF_RETURN_IF_ERROR(
+ result_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&buffers, &buffers_in_result, &result_buffer, this](
+ const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
+ if (is_leaf) {
+ const std::vector<const LogicalBuffer*>& sources =
+ this->GetRootPointsToSet().element(index);
+ // The points to set is unambiguous so the set should be a
+ // singleton.
+ CHECK_EQ(1, sources.size());
+ const LogicalBuffer* buffer_source = sources[0];
+ HloInstruction* src = buffer_source->instruction();
+
+ // The source for this result buffer can be a nested buffer
+ // such as a tuple element.
+
+ // The source instruction should have a non-parameter buffer
+ // assigned.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ this->assignment_->GetUniqueAllocation(
+ src, buffer_source->index()));
+ CHECK(!allocation->is_entry_computation_parameter());
+
+ CHECK(!buffers[allocation->index()].is_null() ||
+ buffers[allocation->index()].size() == 0);
+ result_buffer->mutable_buffers()->push_back(
+ buffers[allocation->index()]);
+ *buffer_entry = result_buffer->mutable_buffers()->size() - 1;
+ buffers_in_result[allocation->index()] = true;
+ }
+ return Status::OK();
+ }));
+
+ // Free all buffers not in the result.
+ for (auto i = 0; i < buffers.size(); ++i) {
+ auto alloc = buffers[i];
+ if (!buffers_in_result[i] && !alloc.is_null()) {
+ VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
+ << alloc.opaque() << "]";
+ TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
+ stream->parent()->device_ordinal(), &alloc));
+ }
+ }
+
+ return std::move(result_buffer);
+}
+
+Status CpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ // Every array element in the result of the computation must be unambiguously
+ // produced by a single instruction.
+ // This ensures that the buffers inside result_buffer can be assigned without
+ // conflict to the respective instructions because there is a one-to-one
+ // correspondence between hlo instructions and array buffers in the result.
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented(
+ "Points-to set of root instruction is ambiguous or not distinct");
+ }
+ std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+ DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape()));
+
+ // If two tuple elements point to the same buffer, one of the results in the
+ // result buffer is considered the canonical location while the other result
+ // points to it (instead of, say, making a copy of the result).
+ // buffer_index_to_shape_index maps a buffer index to its canonical location
+ // in the result buffer.
+ std::unordered_map<BufferAllocation::Index, size_t>
+ buffer_index_to_shape_index;
+
+ // Copy values from result_buffer to the index in "buffers". These buffers
+ // will not be allocated in the call to AllocateBuffers.
+ std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
+ TF_RETURN_IF_ERROR(
+ result_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&buffers, &buffers_in_result, &buffer_index_to_shape_index,
+ result_buffer, this](const ShapeIndex& index, bool is_leaf,
+ size_t* buffer_entry) {
+ if (is_leaf) {
+ const std::vector<const LogicalBuffer*>& sources =
+ this->GetRootPointsToSet().element(index);
+ // The points to set is unambiguous so the set should be a
+ // singleton.
+ CHECK_EQ(1, sources.size());
+ const LogicalBuffer* buffer_source = sources[0];
+ HloInstruction* src = buffer_source->instruction();
+
+ // The source for this result buffer can be a nested buffer
+ // such as a tuple element.
+
+ // The source instruction should have a non-parameter buffer
+ // assigned.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ this->assignment_->GetUniqueAllocation(
+ src, buffer_source->index()));
+ CHECK(!allocation->is_entry_computation_parameter());
+
+ auto insert_result = buffer_index_to_shape_index.emplace(
+ allocation->index(), *buffer_entry);
+ if (insert_result.second) {
+ // The points-to set is distinct so this buffer should not have
+ // been assigned in a previous invocation of this lambda.
+ perftools::gputools::DeviceMemoryBase memory_base =
+ result_buffer->buffer(index);
+ CHECK(buffers[allocation->index()].is_null());
+ CHECK(!memory_base.is_null());
+ buffers[allocation->index()] = memory_base;
+ buffers_in_result[allocation->index()] = true;
+ } else {
+ // Record the fact that this tuple element is identical to
+ // some prior result.
+ *buffer_entry = insert_result.first->second;
+ }
+ }
+ return Status::OK();
+ }));
+
+ TF_RETURN_IF_ERROR(AllocateBuffers(
+ memory_allocator, stream->parent()->device_ordinal(), &buffers));
+
+ TF_RETURN_IF_ERROR(ExecuteComputeFunction(run_options, arguments, buffers,
+ hlo_execution_profile));
+
+ // Free all buffers not in the result.
+ for (auto i = 0; i < buffers.size(); ++i) {
+ auto alloc = buffers[i];
+ if (!buffers_in_result[i] && !alloc.is_null()) {
+ VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
+ << alloc.opaque() << "]";
+ TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
+ stream->parent()->device_ordinal(), &alloc));
+ }
+ }
+
+ return Status::OK();
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+CpuExecutable::ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ // TODO(b/30671675): Implement asynchronous execution mode.
+ return Unimplemented(
+ "Asynchronous execution on stream is not yet supported on CPU.");
+}
+
+const PointsToSet& CpuExecutable::GetRootPointsToSet() const {
+ return assignment_->points_to_analysis().GetPointsToSet(
+ module().entry_computation()->root_instruction());
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
new file mode 100644
index 0000000000..8f3247e683
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -0,0 +1,150 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
+
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace cpu {
+
+// CPU-targeting implementation of the XLA Executable interface.
+//
+// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
+// architecture, so JIT-ed code and host code share the same ABI.
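+//
+// A minimal invocation sketch (assumes a compiled `executable`, a populated
+// `run_options`, and argument buffers `args`; all names are illustrative):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       perftools::gputools::DeviceMemoryBase result,
+//       executable->ExecuteOnStream(&run_options, args,
+//                                   /*hlo_execution_profile=*/nullptr));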
+class CpuExecutable : public Executable {
+ public:
+ CpuExecutable(
+ std::unique_ptr<SimpleOrcJIT> jit,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ const string& entry_function_name,
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx);
+ ~CpuExecutable() override {}
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ Status ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) override;
+
+ // This should be called after set_ir_module_string.
+ const string& ir_module_string() const { return ir_module_string_; }
+
+ void set_ir_module_string(const string& ir_module_string) {
+ ir_module_string_ = ir_module_string;
+ }
+
+ private:
+ // Allocate buffers required for execution and assign them to the elements of
+ // "buffers". "buffers" should be sized to the number of buffers in buffer
+ // assignment. Each vector element corresponds to a particular Index. If
+ // a vector element already contains a non-null DeviceMemoryBase, then no
+ // buffer is assigned for this element.
+ Status AllocateBuffers(
+ DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* buffers);
+
+ // Calls the generated function performing the computation with the given
+ // arguments using the supplied buffers.
+ Status ExecuteComputeFunction(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ buffers,
+ HloExecutionProfile* hlo_execution_profile);
+ Status ExecuteComputeFunction(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ buffers,
+ HloExecutionProfile* hlo_execution_profile);
+
+ // Returns the points-to set of the root instruction of the entry
+ // computation. Uses points-to analysis from buffer assignment.
+ const PointsToSet& GetRootPointsToSet() const;
+
+ // The JIT containing compiled modules.
+ std::unique_ptr<SimpleOrcJIT> jit_;
+
+ // Buffer assignment for the buffers we need to allocate.
+ std::unique_ptr<BufferAssignment> assignment_;
+
+ // The LLVM IR, in string format, of the unoptimized module generated for this
+ // CpuExecutable. We save a string instead of an llvm::Module* because leaving
+ // llvm::Module* in a singleton can cause the heap checker to emit false
+ // positives.
+ string ir_module_string_;
+
+ // Type of the computation function we expect in the JIT.
+ // void function(void* result, const void* run_options,
+ // const void** args_array, void** temps_array,
+ // uint64* profile_counters)
+ using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
+ uint64*);
+ ComputeFunctionType compute_function_;
+
+ // Entry function name for the computation.
+ const string entry_function_name_;
+
+ // Maps HLOs to their index into the profile counter array.
+ const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CpuExecutable);
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
new file mode 100644
index 0000000000..240da35ef1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+namespace cpu {
+
+bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
+ int64 operand_index) {
+ HloInstruction* producer = consumer->mutable_operand(operand_index);
+
+ // Condition for the consumer: it must be elementwise or a fusion op (which
+ // necessarily contains only elementwise operations).
+ if (!(consumer->opcode() == HloOpcode::kFusion ||
+ consumer->IsElementwise())) {
+ return false;
+ }
+
+ // Producer or consumer cannot be a Map. Maps are technically elementwise but
+ // of a slightly different form (they apply a called computation rather than
+ // a primitive elementwise operation). These are not yet supported in the CPU
+ // backend.
+ return producer->IsElementwise() && producer->operand_count() > 0 &&
+ producer->opcode() != HloOpcode::kMap &&
+ consumer->opcode() != HloOpcode::kMap &&
+ InstructionFusion::ShouldFuse(consumer, operand_index);
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h
new file mode 100644
index 0000000000..b7c646ad47
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+
+namespace xla {
+namespace cpu {
+
+class CpuInstructionFusion : public InstructionFusion {
+ public:
+ CpuInstructionFusion() {}
+ ~CpuInstructionFusion() override {}
+
+ protected:
+ bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_INSTRUCTION_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
new file mode 100644
index 0000000000..7ae81929c0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace cpu {
+
+StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
+ bool changed = false;
+ HloComputation* entry_computation = module->entry_computation();
+ std::unordered_set<HloInstruction*> outlined;
+ std::vector<HloInstruction*> instructions_to_outline;
+ for (HloInstruction* instruction :
+ entry_computation->MakeInstructionPostOrder()) {
+ // If the instruction has been outlined, it no longer exists and we must not
+ // dereference it.
+ if (outlined.count(instruction) > 0) {
+ continue;
+ }
+
+ // Skip parameters and constants, there is nothing to parallelize.
+ if (instruction->opcode() == HloOpcode::kParameter ||
+ instruction->opcode() == HloOpcode::kConstant) {
+ continue;
+ }
+ instructions_to_outline.clear();
+ HloInstruction* outline_candidate = instruction;
+ instructions_to_outline.push_back(outline_candidate);
+ bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast;
+
+ // Outline sole users with the current instruction.
+ while (outline_candidate->users().size() == 1) {
+ HloInstruction* prior_candidate = outline_candidate;
+ outline_candidate = *outline_candidate->users().begin();
+ all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast;
+ if (std::any_of(outline_candidate->operands().begin(),
+ outline_candidate->operands().end(),
+ [&](const HloInstruction* operand) {
+ // Do not consider any candidates which have operands
+ // other than the prior candidate, constants or
+ // parameters. Otherwise, we'd increase the fan-in which
+ // would reduce parallelism.
+ return operand->opcode() != HloOpcode::kParameter &&
+ operand->opcode() != HloOpcode::kConstant &&
+ operand != prior_candidate;
+ })) {
+ break;
+ }
+ instructions_to_outline.push_back(outline_candidate);
+ }
+ // If all instructions in the outline candidates are bitcasts, then create
+ // a copy at the head of the bitcasts and include it in the outlined
+ // instructions. The underlying problem is that a computation which forwards
+ // a parameter buffer to the output is not properly handled by the backends
+ // or analysis.
+ //
+ // This would be better handled by being smarter about choosing outline
+ // candidates in the first place.
+ if (all_bitcasts) {
+ // 'head' is the first instruction in the chain of bitcasts.
+ HloInstruction* head = instructions_to_outline[0];
+ HloInstruction* head_operand = head->mutable_operand(0);
+ HloInstruction* copy =
+ entry_computation->AddInstruction(HloInstruction::CreateUnary(
+ head_operand->shape(), HloOpcode::kCopy, head_operand));
+ head->ReplaceOperandWith(0, copy);
+ instructions_to_outline.insert(instructions_to_outline.begin(), copy);
+ }
+
+ outlined.insert(instructions_to_outline.begin(),
+ instructions_to_outline.end());
+
+ module->OutlineExpressionFromComputation(
+ instructions_to_outline,
+ tensorflow::strings::StrCat("computation_for_", instruction->name()),
+ entry_computation);
+ changed = true;
+ }
+
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+ for (auto& computation : module->computations()) {
+ HloInstruction* root = computation->root_instruction();
+ // Copy root instruction if it does not define its own top-level buffer.
+ // TODO(b/32885001) Remove these copies (at least for the unambiguous case).
+ // TODO(b/32885001) Perform shallow copy if root value is a tuple.
+ if (!points_to_analysis->InstructionDefinesBufferAtIndex(root,
+ /*index=*/{})) {
+ HloInstruction* copy = computation->AddInstruction(
+ HloInstruction::CreateUnary(root->shape(), HloOpcode::kCopy, root));
+ computation->set_root_instruction(copy);
+ changed = true;
+ }
+ }
+ return changed;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
new file mode 100644
index 0000000000..3d6cfb258f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+namespace cpu {
+
+// This pass prepares an HLO module for parallel execution by transforming
+// subgraphs of the top-level computation into embedded computations which can
+// be executed in parallel.
+// TODO(b/29630486): Currently, it is limited to turning all instructions (which
+// are not constants or parameters) in the entry computation into embedded
+// computations. However, it could make sense to coarsen the parallelization to
+// improve cache locality. Also, we will need to do something to intelligently
+// handle While constructs.
+class ParallelizationPreparation : public HloPass {
+ public:
+ explicit ParallelizationPreparation() : HloPass("cpu-parallel-prepare") {}
+ ~ParallelizationPreparation() override {}
+
+  // Run parallelization preparation on the given module. Returns whether the
+  // module was changed.
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
new file mode 100644
index 0000000000..8e06f0520e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+
+#include <sched.h>
+#include <functional>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+InfeedManager* GetInfeedManager() {
+ static InfeedManager* manager = new InfeedManager;
+ return manager;
+}
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
+
+void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ xla::int32 buffer_length) {
+ xla::cpu::runtime::InfeedManager* infeed =
+ xla::cpu::runtime::GetInfeedManager();
+ // Wait until there's a buffer to dequeue.
+ xla::cpu::runtime::InfeedBuffer* buffer = infeed->BlockingDequeueBuffer();
+ CHECK_EQ(buffer->length(), buffer_length);
+ return buffer->data();
+}
+
+void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(xla::int32 buffer_length,
+ void* buffer_ptr) {
+ xla::cpu::runtime::InfeedManager* infeed =
+ xla::cpu::runtime::GetInfeedManager();
+ infeed->ReleaseCurrentBuffer(buffer_length, buffer_ptr);
+}
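
For context, here is a minimal sketch of the call sequence compiled code is expected to follow around these two entry points; ConsumePayload is a hypothetical stand-in for whatever the generated computation does with the bytes.

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

// Hypothetical consumer of the infeed payload.
void ConsumePayload(const void* data, xla::int32 length);

void ReadOneInfeedBuffer(xla::int32 length) {
  // Blocks until a client has enqueued a buffer; CHECK-fails if its length
  // does not match.
  void* data = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length);
  ConsumePayload(data, length);
  // After the release the client may reclaim the storage, so `data` must not
  // be touched again; only one buffer may be outstanding at a time.
  __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, data);
}
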
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
new file mode 100644
index 0000000000..8eae210230
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header declares functions which may be called by the generated code on
+// the CPU. Calls to these functions must be resolved explicitly in the JIT in
+// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
+// which is used to cache expensive state and resources utilized by the
+// aforementioned functions.
+//
+// Other functions are declared in individual libraries as well, such as
+// runtime_conv2d and runtime_matmul. As individual libraries, callers for
+// ahead-of-time compilation can link only the required subset.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
+
+#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+// Names of runtime functions. These get resolved from the generated code to the
+// right symbol at link time in one of two ways:
+// 1. When using the JIT, the symbol resolver (SimpleResolver in
+//    third_party/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc) maps
+//    this symbol name to the actual symbol.
+// 2. When using ahead-of-time compilation, the linker can resolve the name
+// because it is a symbol in the cpu_runtime library.
+constexpr char kEigenMatmulF32SymbolName[] = "__xla_cpu_runtime_EigenMatMulF32";
+constexpr char kEigenMatmulF64SymbolName[] = "__xla_cpu_runtime_EigenMatMulF64";
+constexpr char kEigenConvF32SymbolName[] = "__xla_cpu_runtime_EigenConvF32";
+constexpr char kEigenSingleThreadedMatmulF32SymbolName[] =
+ "__xla_cpu_runtime_EigenSingleThreadedMatMulF32";
+constexpr char kEigenSingleThreadedMatmulF64SymbolName[] =
+ "__xla_cpu_runtime_EigenSingleThreadedMatMulF64";
+constexpr char kEigenSingleThreadedConvF32SymbolName[] =
+ "__xla_cpu_runtime_EigenSingleThreadedConvF32";
+constexpr char kAcquireInfeedBufferForDequeueSymbolName[] =
+ "__xla_cpu_runtime_AcquireInfeedBufferForDequeue";
+constexpr char kReleaseInfeedBufferAfterDequeueSymbolName[] =
+ "__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue";
+
+// Returns the infeed manager used by the CPU runtime.
+InfeedManager* GetInfeedManager();
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
+
+extern "C" {
+
+// Blocks until the next infeed buffer is ready to be dequeued, then
+// returns it. Fails catastrophically if the next enqueued buffer is
+// not of the correct length in bytes. Checking the shape rather than
+// the length would be more exact, but the length check is chosen as a
+// tradeoff between error checking and speed/simplicity.
+extern void* __xla_cpu_runtime_AcquireInfeedBufferForDequeue(
+ xla::int32 buffer_length);
+
+// Relinquishes the next infeed buffer that was returned by
+// __xla_cpu_runtime_AcquireInfeedBufferForDequeue. Once this call
+// completes the data at buffer_ptr may no longer be
+// accessed. buffer_length must match the length passed to the call to
+// __xla_cpu_runtime_AcquireInfeedBufferForDequeue that returned
+// buffer_ptr. This function must be called before the next buffer is
+// acquired, i.e., there may only be one outstanding infeed buffer in
+// use by the runtime. TODO(b/31340454) investigate whether or not it
+// is worth supporting zero-copy infeed where the buffer is retained
+// by the compiled code until it has been used. If zero-copy infeed is
+// implemented we will add support for multiple outstanding buffers
+// that can be returned out of order.
+extern void __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(
+ xla::int32 buffer_length, void* buffer_ptr);
+}
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_H_
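
To make the name-to-address mapping concrete, here is a minimal sketch of the kind of table a JIT-side resolver could consult. The real mapping lives in SimpleResolver (simple_orc_jit.cc); BuildRuntimeSymbolTable below is purely illustrative.

#include <string>
#include <unordered_map>

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"

// Illustrative only: maps the exported symbol names to host addresses so a
// resolver can patch calls emitted by the generated code.
std::unordered_map<std::string, void*> BuildRuntimeSymbolTable() {
  namespace rt = xla::cpu::runtime;
  return {
      {rt::kAcquireInfeedBufferForDequeueSymbolName,
       reinterpret_cast<void*>(
           &__xla_cpu_runtime_AcquireInfeedBufferForDequeue)},
      {rt::kReleaseInfeedBufferAfterDequeueSymbolName,
       reinterpret_cast<void*>(
           &__xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue)},
      // The Eigen matmul/conv entry points would be registered the same way.
  };
}
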
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc
new file mode 100644
index 0000000000..646254887c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/Eigen/Core"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+#ifdef __AVX__
+V8F32 ExpV8F32(V8F32 x) { return Eigen::internal::pexp(x); }
+
+V8F32 LogV8F32(V8F32 x) { return Eigen::internal::plog(x); }
+
+V8F32 TanhV8F32(V8F32 x) { return Eigen::internal::ptanh(x); }
+#endif // __AVX__
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h
new file mode 100644
index 0000000000..89721aaf83
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header declares functions which may be called by the generated code on
+// the CPU. Calls to these functions must be resolved explicitly in the JIT in
+// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
+// which is used to cache expensive state and resources utilized by the
+// aforementioned functions.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
+
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+constexpr char kExpV8F32[] = "__xla_cpu_runtime_ExpV8F32";
+constexpr char kLogV8F32[] = "__xla_cpu_runtime_LogV8F32";
+constexpr char kTanhV8F32[] = "__xla_cpu_runtime_TanhV8F32";
+
+typedef float V8F32 __attribute__((__vector_size__(32)));
+
+// The following functions are vectorized versions of a selection of libm
+// library functions.
+// References to these functions are created by the LLVM vectorizer.
+V8F32 ExpV8F32(V8F32 x) TF_ATTRIBUTE_WEAK;
+
+V8F32 LogV8F32(V8F32 x) TF_ATTRIBUTE_WEAK;
+
+V8F32 TanhV8F32(V8F32 x) TF_ATTRIBUTE_WEAK;
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
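
A small host-side sketch of how the 8-wide type behaves. This is plain use of the GCC/Clang vector extension; the functions are normally reached from vectorized generated code rather than called by hand.

#include <cstdio>

#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"

int main() {
#ifdef __AVX__
  using xla::cpu::runtime::V8F32;
  // Vector-extension types support brace initialization and subscripting.
  V8F32 x = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 2.5f, 3.0f, 3.5f};
  V8F32 y = xla::cpu::runtime::ExpV8F32(x);  // elementwise exp of all 8 lanes
  for (int i = 0; i < 8; ++i) {
    std::printf("exp(%f) = %f\n", x[i], y[i]);
  }
#endif  // __AVX__
  return 0;
}
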
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc
new file mode 100644
index 0000000000..69d04427c6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/Eigen/Core"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+#ifdef __SSE4_1__
+
+V4F32 ExpV4F32(V4F32 x) {
+ Eigen::internal::Packet4f p = x;
+ return Eigen::internal::pexp(p);
+}
+
+V4F32 LogV4F32(V4F32 x) {
+ Eigen::internal::Packet4f p = x;
+ return Eigen::internal::plog(p);
+}
+
+V4F32 TanhV4F32(V4F32 x) {
+ Eigen::internal::Packet4f p = x;
+ return Eigen::internal::ptanh(p);
+}
+
+#endif // __SSE4_1__
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h
new file mode 100644
index 0000000000..ded206f90a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header declares functions which may be called by the generated code on
+// the CPU. Calls to these functions must be resolved explicitly in the JIT in
+// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
+// which is used to cache expensive state and resources utilized by the
+// aforementioned functions.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
+
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+constexpr char kExpV4F32[] = "__xla_cpu_runtime_ExpV4F32";
+constexpr char kLogV4F32[] = "__xla_cpu_runtime_LogV4F32";
+constexpr char kTanhV4F32[] = "__xla_cpu_runtime_TanhV4F32";
+
+typedef float V4F32 __attribute__((__vector_size__(16)));
+
+// The following functions are vectorized versions of a selection of libm
+// library functions.
+// References to these functions are created by the LLVM vectorizer.
+V4F32 ExpV4F32(V4F32 x) TF_ATTRIBUTE_WEAK;
+
+V4F32 LogV4F32(V4F32 x) TF_ATTRIBUTE_WEAK;
+
+V4F32 TanhV4F32(V4F32 x) TF_ATTRIBUTE_WEAK;
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
new file mode 100644
index 0000000000..52eed7dbad
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+
+#include <memory>
+#include <string>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CpuRuntimeTest : public ::testing::Test {};
+
+template <typename T>
+std::unique_ptr<Array2D<float>> MaybeTransposeArray2D(const Array2D<T>& array,
+ bool transpose) {
+ int64 output_height = array.height();
+ int64 output_width = array.width();
+ if (transpose) {
+ std::swap(output_width, output_height);
+ }
+ auto output = MakeUnique<Array2D<float>>(output_height, output_width);
+ for (int y = 0; y < array.height(); y++) {
+ for (int x = 0; x < array.width(); x++) {
+ if (transpose) {
+ (*output)(x, y) = array(y, x);
+ } else {
+ (*output)(y, x) = array(y, x);
+ }
+ }
+ }
+ return output;
+}
+
+// Verifies that matrix 'c' equals the result of matrix 'a' times matrix 'b'.
+// Each element is compared to within a small error bound.
+void CheckMatrixMultiply(const Array2D<float>& a, const Array2D<float>& b,
+ const Array2D<float>& c) {
+ for (int i = 0; i < a.height(); ++i) {
+ for (int j = 0; j < b.width(); ++j) {
+ float sum = 0.0;
+ for (int k = 0; k < a.width(); ++k) {
+ sum += a(i, k) * b(k, j);
+ }
+ EXPECT_NEAR(sum, c(i, j), 0.01);
+ }
+ }
+}
+
+std::unique_ptr<Array2D<float>> EigenMatrixMultiply(const Array2D<float>& a,
+ const Array2D<float>& b,
+ bool transpose_lhs,
+ bool transpose_rhs) {
+ tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
+ 2);
+ tensorflow::EigenThreadPoolWrapper tp(&pool);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+ ExecutableRunOptions run_options;
+ run_options.set_intra_op_thread_pool(&device);
+
+ CHECK_EQ(a.width(), b.height());
+ int64 m = a.height();
+ int64 n = b.width();
+ int64 k = a.width();
+
+  // The Eigen matmul runtime function expects matrices in column-major order,
+  // while Array2D stores elements in row-major order, so create transposes of
+  // a and b. The 'data' buffer of a transposed array is the original array in
+  // column-major order.
+ auto a_transpose = MaybeTransposeArray2D(a, !transpose_lhs);
+ auto b_transpose = MaybeTransposeArray2D(b, !transpose_rhs);
+
+  // Since we're going to transpose c before returning it, swap the order of the
+  // dimension sizes to ensure the returned array is properly dimensioned.
+ auto c_transpose = MakeUnique<Array2D<float>>(n, m);
+ __xla_cpu_runtime_EigenMatMulF32(&run_options, c_transpose->data(),
+ a_transpose->data(), b_transpose->data(), m,
+ n, k, transpose_lhs, transpose_rhs);
+ return MaybeTransposeArray2D(*c_transpose, true);
+}
+
+TEST_F(CpuRuntimeTest, SmallEigenMatmul) {
+ Array2D<float> a({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ Array2D<float> b({{5.0f, -1.0f, 3.0f}, {2.0f, 6.0f, 4.0f}});
+
+ for (bool transpose_lhs : {false, true}) {
+ for (bool transpose_rhs : {false, true}) {
+ auto c = EigenMatrixMultiply(a, b, transpose_lhs, transpose_rhs);
+
+ LOG(INFO) << "a = " << a.ToString();
+ LOG(INFO) << "b = " << b.ToString();
+ LOG(INFO) << "c = " << c->ToString();
+
+ CheckMatrixMultiply(a, b, *c);
+ }
+ }
+}
+
+TEST_F(CpuRuntimeTest, LargeEigenMatmul) {
+ auto a = MakeLinspaceArray2D(0.0, 1.0, 256, 512);
+ auto b = MakeLinspaceArray2D(-2.0, 2.0, 512, 1024);
+
+ for (bool transpose_lhs : {false, true}) {
+ for (bool transpose_rhs : {false, true}) {
+ auto c = EigenMatrixMultiply(*a, *b, transpose_lhs, transpose_rhs);
+
+ CheckMatrixMultiply(*a, *b, *c);
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.cc b/tensorflow/compiler/xla/service/cpu/disassembler.cc
new file mode 100644
index 0000000000..f0dcce56b4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/disassembler.cc
@@ -0,0 +1,182 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
+
+#include <stdint.h>
+#include <algorithm>
+// IWYU pragma: no_include <system_error>
+#include <type_traits>
+#include <vector>
+
+#include "external/llvm/include/llvm/MC/MCInst.h"
+#include "external/llvm/include/llvm/Support/TargetRegistry.h"
+#include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace cpu {
+
+Disassembler::Disassembler(const llvm::TargetMachine& target_machine)
+ : subtarget_info_(*target_machine.getMCSubtargetInfo()) {
+ objfile_info_.reset(new llvm::MCObjectFileInfo());
+ mc_context_.reset(new llvm::MCContext(target_machine.getMCAsmInfo(),
+ target_machine.getMCRegisterInfo(),
+ objfile_info_.get()));
+ disassembler_.reset(target_machine.getTarget().createMCDisassembler(
+ subtarget_info_, *mc_context_));
+ inst_printer_.reset(target_machine.getTarget().createMCInstPrinter(
+ target_machine.getTargetTriple(),
+ /*SyntaxVariant=*/0, // Use AT&T syntax.
+ *target_machine.getMCAsmInfo(), *target_machine.getMCInstrInfo(),
+ *target_machine.getMCRegisterInfo()));
+ inst_analysis_.reset(target_machine.getTarget().createMCInstrAnalysis(
+ target_machine.getMCInstrInfo()));
+}
+
+// This code is based on llvm-objdump in llvm/tools.
+StatusOr<string> Disassembler::DisassembleObjectFile(
+ const llvm::object::ObjectFile& object_file) const {
+ if (disassembler_ == nullptr) {
+ return NotFound("could not find a disassembler for this platform");
+ }
+
+ std::string buffer_string;
+ llvm::raw_string_ostream ostream(buffer_string);
+
+ // Iterate through sections. Disassemble symbols of the text section(s).
+ for (auto& section : object_file.sections()) {
+ if (!section.isText()) {
+ continue;
+ }
+
+ // Gather symbols from the section.
+ std::vector<llvm::object::SymbolRef> symbols;
+ for (auto& symbol : object_file.symbols()) {
+ if (section.containsSymbol(symbol)) {
+ symbols.push_back(symbol);
+ }
+ }
+
+ // Sort the symbols in increasing address order.
+ std::sort(
+ symbols.begin(), symbols.end(),
+ [](const llvm::object::SymbolRef& a, const llvm::object::SymbolRef& b) {
+          // getAddress returns an Expected object. Assert there is no error
+ // before extracting the address.
+ llvm::Expected<uint64_t> a_address_or_error = a.getAddress();
+ CHECK(a_address_or_error);
+ llvm::Expected<uint64_t> b_address_or_error = b.getAddress();
+ CHECK(b_address_or_error);
+ return a_address_or_error.get() < b_address_or_error.get();
+ });
+
+ // Construct ArrayRef pointing to section contents.
+ llvm::StringRef section_content_string;
+ if (section.getContents(section_content_string)) {
+ continue;
+ }
+ llvm::ArrayRef<uint8_t> section_content_bytes(
+ reinterpret_cast<const uint8*>(section_content_string.data()),
+ section_content_string.size());
+
+ // Use int types from LLVM (eg, uint64_t) for values passed to and returned
+ // from the LLVM API. These values map to different types in LLVM and
+ // XLA (unsigned long vs unsigned long long).
+ uint64_t section_address = section.getAddress();
+ uint64_t section_size = section.getSize();
+
+ // Iterate through symbols in increasing address order and disassemble each
+ // one.
+ for (int i = 0; i < symbols.size(); ++i) {
+ auto symbol = symbols[i];
+ llvm::Expected<uint64_t> address = symbol.getAddress();
+ CHECK(address);
+ uint64_t start_index = address.get() - section_address;
+
+ // End of symbol is either the end of the section or the start of the next
+ // symbol.
+ uint64_t end_index;
+ if (i < symbols.size() - 1) {
+ llvm::Expected<uint64_t> next_address = symbols[i + 1].getAddress();
+ CHECK(next_address);
+ end_index = std::min(section_size, next_address.get());
+ } else {
+ end_index = section_size;
+ }
+
+ // Skip zero-length symbols.
+ if (start_index == end_index) {
+ continue;
+ }
+
+ llvm::Expected<llvm::StringRef> name_or_error = symbol.getName();
+ TF_RET_CHECK(name_or_error);
+ ostream << name_or_error.get().str() << ":\n";
+
+ // Disassemble symbol instruction-by-instruction.
+ uint64_t index = start_index;
+ while (index < end_index) {
+ llvm::MCInst instruction;
+ uint64_t size;
+ llvm::MCDisassembler::DecodeStatus decode_status =
+ disassembler_->getInstruction(instruction, size,
+ section_content_bytes.slice(index),
+ /*Address=*/section_address + index,
+ /*VStream=*/llvm::nulls(),
+ /*CStream=*/llvm::nulls());
+ // If we fail to disassemble, then we must skip past this address.
+ if (size == 0) {
+ size = 1;
+ }
+
+ ostream << tensorflow::strings::Printf("0x%08lx", index) << " ";
+
+ if (decode_status == llvm::MCDisassembler::Success) {
+ // For branches, try to determine the actual address and emit it as an
+ // annotation.
+ string annotation;
+ if (inst_analysis_ &&
+ (inst_analysis_->isUnconditionalBranch(instruction) ||
+ inst_analysis_->isConditionalBranch(instruction))) {
+ uint64_t target;
+ if (inst_analysis_->evaluateBranch(
+ instruction, section_address + index, size, target)) {
+ annotation = tensorflow::strings::Printf("[0x%08lx]", target);
+ }
+ }
+ inst_printer_->printInst(&instruction, ostream, annotation.c_str(),
+ subtarget_info_);
+ } else {
+ ostream << " <unknown>";
+ }
+
+ ostream << "\n";
+ index += size;
+ }
+ }
+ }
+
+ ostream.flush();
+ return string(buffer_string.data(), buffer_string.length());
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h
new file mode 100644
index 0000000000..e90f26fc82
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/disassembler.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_
+
+#include <memory>
+#include <string>
+
+#include "external/llvm/include/llvm/MC/MCContext.h"
+#include "external/llvm/include/llvm/MC/MCDisassembler/MCDisassembler.h"
+#include "external/llvm/include/llvm/MC/MCInstPrinter.h"
+#include "external/llvm/include/llvm/MC/MCInstrAnalysis.h"
+#include "external/llvm/include/llvm/MC/MCObjectFileInfo.h"
+#include "external/llvm/include/llvm/MC/MCSubtargetInfo.h"
+#include "external/llvm/include/llvm/Object/ObjectFile.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace cpu {
+
+// Class for disassembling object files (and potentially other constructs) into
+// X86 assembly. Builds all the LLVM disassembly and instruction printing
+// constructs from a given TargetMachine.
+class Disassembler {
+ public:
+ explicit Disassembler(const llvm::TargetMachine& target_machine);
+
+ // Returns a string containing the disassembled text sections of the given
+ // object file.
+ //
+  // If we couldn't retrieve a disassembler for this platform, an error status
+ // is returned.
+ StatusOr<string> DisassembleObjectFile(
+ const llvm::object::ObjectFile& object_file) const;
+
+ private:
+ const llvm::MCSubtargetInfo& subtarget_info_;
+ std::unique_ptr<llvm::MCObjectFileInfo> objfile_info_;
+ std::unique_ptr<llvm::MCContext> mc_context_;
+ std::unique_ptr<llvm::MCDisassembler> disassembler_;
+ std::unique_ptr<llvm::MCInstPrinter> inst_printer_;
+ std::unique_ptr<llvm::MCInstrAnalysis> inst_analysis_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DISASSEMBLER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
new file mode 100644
index 0000000000..420f9cebc5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -0,0 +1,346 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
+
+#include <memory>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+using llvm_ir::SetToFirstInsertPoint;
+
+namespace cpu {
+
+DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
+ bool transpose_rhs,
+ const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array,
+ const llvm_ir::IrArray& rhs_array,
+ llvm::Value* executable_run_options_value,
+ llvm::IRBuilder<>* ir_builder)
+ : dot_(dot),
+ transpose_lhs_(transpose_lhs),
+ transpose_rhs_(transpose_rhs),
+ target_array_(target_array),
+ lhs_array_(lhs_array),
+ rhs_array_(rhs_array),
+ executable_run_options_value_(executable_run_options_value),
+ ir_builder_(ir_builder) {}
+
+/* static */ tensorflow::Status DotOpEmitter::EmitDotOperation(
+ const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
+ const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
+ const llvm_ir::IrArray& rhs_array,
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder) {
+ PrimitiveType type = target_array.GetShape().element_type();
+ TF_RET_CHECK(F32 == type || F64 == type);
+ DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array,
+ lhs_array, rhs_array, executable_run_options_value,
+ ir_builder);
+ return dot_emitter.Emit();
+}
+
+bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
+
+tensorflow::Status DotOpEmitter::Emit() {
+ // The dot operation performs a sum of products over dimension 0 of the left
+ // hand side operand and dimension 1 of the right hand side operand.
+ //
+ // Let the shapes of lhs and rhs be defined as below:
+ //
+ // lhs = [L{n-1} x L{n-2} x ... L{0}]
+ // rhs = [R{m-1} x R{m-2} x ... R{0}]
+ //
+ // The sum-of-products dimension in the lhs has size L{0} and the dimension in
+ // the rhs has size R{1}. Necessarily, then:
+ //
+ // L{0} == R{1}
+ //
+ // The output of the operation has the following shape:
+ //
+ // output = [L{n-1} x L{n-2} x ... L{1} x R{m-1} x R{m-2} x ... R{2} x R{0}]
+ //
+ // To perform the operation we construct a loop nest with one for-loop for
+ // each dimension of the output. Inside this loop nest is another for-loop
+ // which performs the sum-of-products (the reduction loop) before storing
+ // the result in the output buffer.
+
+ const Shape& lhs_shape = lhs_array_.GetShape();
+ const Shape& rhs_shape = rhs_array_.GetShape();
+
+ if (ShapeUtil::IsScalar(lhs_shape) || ShapeUtil::IsScalar(rhs_shape)) {
+ // If the operands are scalar, don't emit any loops.
+ TF_RET_CHECK(ShapeUtil::IsScalar(lhs_shape) &&
+ ShapeUtil::IsScalar(rhs_shape));
+ return EmitScalarDot();
+ }
+
+ if (PotentiallyImplementedAsEigenDot(dot_)) {
+ return EmitCallToRuntime();
+ }
+
+ // Reduce along dimension 0 of the LHS and 1 of the RHS. Vectors are a special
+ // case where the reduction dimension is 0 for both LHS and RHS. This results
+ // in a vector dot product producing a scalar.
+ int64 lhs_reduction_dimension = 0;
+ if (ShapeUtil::Rank(lhs_shape) >= 2) {
+ lhs_reduction_dimension =
+ ShapeUtil::GetDimensionNumber(lhs_shape, transpose_lhs_ ? -2 : -1);
+ }
+ int64 rhs_reduction_dimension = 0;
+ if (ShapeUtil::Rank(rhs_shape) >= 2) {
+ rhs_reduction_dimension =
+ ShapeUtil::GetDimensionNumber(rhs_shape, transpose_rhs_ ? -1 : -2);
+ }
+
+ // Verify the reduction dimension in the two operands are the same size.
+ TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
+ rhs_shape.dimensions(rhs_reduction_dimension));
+
+ // Create loop nests which loop through the LHS operand dimensions and the RHS
+ // operand dimensions. The reduction dimension of the LHS and RHS are handled
+ // in a separate innermost loop which performs the sum of products.
+ llvm_ir::ForLoopNest loop_nest(ir_builder_);
+ llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
+ &loop_nest, lhs_array_, lhs_reduction_dimension, "lhs");
+ llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
+ &loop_nest, rhs_array_, rhs_reduction_dimension, "rhs");
+
+ // Create the loop which does the sum of products reduction.
+ std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
+ 0, lhs_shape.dimensions(lhs_reduction_dimension), "reduction");
+
+ // The final entry in the rhs and lhs indexes is the indvar of the
+ // reduction loop.
+ lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
+ rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
+
+ // For computing the sum of products we alloca a single location to store the
+ // dot product result as we accumulate it within the reduction loop. After the
+ // reduction loop we load the result and store into the output array.
+
+ // Function entry basic block.
+ // - Emit alloca for accumulator
+ llvm::Function* func = reduction_loop->GetPreheaderBasicBlock()->getParent();
+ SetToFirstInsertPoint(&func->getEntryBlock(), ir_builder_);
+ llvm::Type* accum_type = target_array_.GetElementLlvmType();
+ llvm::Value* accum_address = ir_builder_->CreateAlloca(
+ accum_type, /*ArraySize=*/nullptr, "accum_address");
+
+ // Preheader basic block of reduction loop:
+ // - Initialize accumulator to zero.
+ llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock();
+ ir_builder_->SetInsertPoint(preheader_bb->getTerminator());
+
+ ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0),
+ accum_address);
+
+ // Body basic block of reduction loop:
+ // - Load elements from lhs and rhs array.
+ // - Multiply lhs-element and rhs-element.
+ // - Load accumulator and add to product.
+ // - Store sum back into accumulator.
+ SetToFirstInsertPoint(reduction_loop->GetBodyBasicBlock(), ir_builder_);
+
+ llvm::Value* lhs_element =
+ lhs_array_.EmitReadArrayElement(lhs_index, ir_builder_);
+ llvm::Value* rhs_element =
+ rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_);
+
+ llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element);
+ llvm::Value* accum = ir_builder_->CreateLoad(accum_address);
+ llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product);
+ ir_builder_->CreateStore(updated_accum, accum_address);
+
+ // Exit basic block of reduction loop.
+ // - Load accumulator value (the result).
+ // - Store into output array.
+ SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), ir_builder_);
+
+ llvm::Value* result = ir_builder_->CreateLoad(accum_address);
+
+  // Create the index into the target array. The target index is the
+  // concatenation of the lhs and rhs indexes with the reduction dimensions
+  // removed. The lhs terms are the major (leftmost) dimensions of the output,
+  // so they are added first.
+ llvm_ir::IrArray::Index target_index;
+ for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
+ if (dimension != lhs_reduction_dimension) {
+ target_index.push_back(lhs_index[dimension]);
+ }
+ }
+ for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
+ if (dimension != rhs_reduction_dimension) {
+ target_index.push_back(rhs_index[dimension]);
+ }
+ }
+
+ target_array_.EmitWriteArrayElement(target_index, result, ir_builder_);
+
+ // Set the IR builder insert point to the exit basic block of the outer most
+ // loop.
+ ir_builder_->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status DotOpEmitter::EmitScalarDot() {
+ // A scalar dot is just a scalar multiply.
+ llvm::Value* lhs_value =
+ lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
+ llvm::Value* rhs_value =
+ rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_);
+ llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value);
+ target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status DotOpEmitter::EmitCallToRuntime() {
+ DCHECK(ShapesAreLegalForRuntimeDot());
+
+ // The signature of the Eigen runtime matmul function is:
+ //
+ // (void)(void* run_options, float* out, float* lhs, float* rhs,
+ // int64 m, int64 n, int64 k, int32 transpose_lhs,
+ // int32 transpose_rhs);
+ // The two transpose_... parameters are actually booleans, but we use int32
+ // to avoid target-dependent calling convention details.
+
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ bool multi_threaded = flags->xla_cpu_multi_thread_eigen;
+ PrimitiveType type = target_array_.GetShape().element_type();
+ llvm::Type* float_type;
+ const char* fn_name;
+ switch (type) {
+ case F32:
+ fn_name = multi_threaded
+ ? runtime::kEigenMatmulF32SymbolName
+ : runtime::kEigenSingleThreadedMatmulF32SymbolName;
+ float_type = ir_builder_->getFloatTy();
+ break;
+ case F64:
+ fn_name = multi_threaded
+ ? runtime::kEigenMatmulF64SymbolName
+ : runtime::kEigenSingleThreadedMatmulF64SymbolName;
+ float_type = ir_builder_->getDoubleTy();
+ break;
+ default:
+ return Unimplemented("Invalid type %s for dot operation",
+ PrimitiveType_Name(type).c_str());
+ }
+
+ llvm::Type* float_ptr_type = float_type->getPointerTo();
+ llvm::Type* int64_type = ir_builder_->getInt64Ty();
+ llvm::Type* int32_type = ir_builder_->getInt32Ty();
+ llvm::Type* int8_ptr_type = ir_builder_->getInt8Ty()->getPointerTo();
+ llvm::FunctionType* matmul_type = llvm::FunctionType::get(
+ ir_builder_->getVoidTy(),
+ {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
+ int64_type, int64_type, int64_type, int32_type, int32_type},
+ /*isVarArg=*/false);
+
+ llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
+ llvm::Module* module = function->getParent();
+
+ llvm::Function* matmul_func = llvm::cast<llvm::Function>(
+ module->getOrInsertFunction(fn_name, matmul_type));
+ matmul_func->setCallingConv(llvm::CallingConv::C);
+ matmul_func->setDoesNotThrow();
+ matmul_func->setOnlyAccessesArgMemory();
+
+ // The Eigen runtime function expects column-major layout. If the matrices are
+ // row major, then use the following identity to compute the product:
+ //
+ // (A x B)^T = B^T x A^T
+ //
+ // Effectively this involves swapping the 'lhs' with 'rhs' and 'm' with 'n'.
+
+ const Shape& lhs_shape = lhs_array_.GetShape();
+ const Shape& rhs_shape = rhs_array_.GetShape();
+ int64 m = lhs_shape.dimensions(transpose_lhs_ ? 1 : 0);
+ int64 k = lhs_shape.dimensions(transpose_lhs_ ? 0 : 1);
+ int64 n = rhs_shape.dimensions(transpose_rhs_ ? 0 : 1);
+ const llvm_ir::IrArray* lhs = &lhs_array_;
+ const llvm_ir::IrArray* rhs = &rhs_array_;
+ bool transpose_lhs = transpose_lhs_;
+ bool transpose_rhs = transpose_rhs_;
+
+ bool is_column_major = lhs_shape.layout().minor_to_major(0) == 0;
+ if (!is_column_major) {
+ std::swap(m, n);
+ std::swap(lhs, rhs);
+ std::swap(transpose_lhs, transpose_rhs);
+ }
+
+ ir_builder_->CreateCall(
+ matmul_func,
+ {ir_builder_->CreateBitCast(executable_run_options_value_, int8_ptr_type),
+ ir_builder_->CreateBitCast(target_array_.GetBasePointer(),
+ float_ptr_type),
+ ir_builder_->CreateBitCast(lhs->GetBasePointer(), float_ptr_type),
+ ir_builder_->CreateBitCast(rhs->GetBasePointer(), float_ptr_type),
+ ir_builder_->getInt64(m), ir_builder_->getInt64(n),
+ ir_builder_->getInt64(k), ir_builder_->getInt32(transpose_lhs),
+ ir_builder_->getInt32(transpose_rhs)});
+ return tensorflow::Status::OK();
+}
+
+llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest(
+ llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
+ int64 reduction_dimension, tensorflow::StringPiece name_suffix) {
+ // Prepares the dimension list we will use to emit the loop nest. Outermost
+ // loops are added first. Add loops in major-to-minor order, and skip the
+ // reduction dimension.
+ std::vector<int64> dimensions;
+ const Shape& shape = operand_array.GetShape();
+ for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) {
+ int64 dimension = shape.layout().minor_to_major(i);
+ if (dimension != reduction_dimension) {
+ dimensions.push_back(dimension);
+ }
+ }
+
+ // Create loop nest with one for-loop for each dimension of the
+ // output.
+ llvm_ir::IrArray::Index index =
+ loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
+ // Verify every dimension except the reduction dimension was set in the index.
+ for (int dimension = 0; dimension < index.size(); ++dimension) {
+ if (dimension == reduction_dimension) {
+ DCHECK_EQ(nullptr, index[dimension]);
+ } else {
+ DCHECK_NE(nullptr, index[dimension]);
+ }
+ }
+ return index;
+}
+
+} // namespace cpu
+} // namespace xla
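
As a readability aid, the loop structure that Emit() builds for the common rank-2, non-transposed case corresponds to the following plain C++ reference (a sketch of the semantics, not the emitted IR): one loop per output dimension, plus an innermost reduction loop that accumulates into a single scalar before the store.

#include <vector>

// out is m x n, lhs is m x k, rhs is k x n; all row-major.
void ReferenceDot(const std::vector<float>& lhs, const std::vector<float>& rhs,
                  std::vector<float>* out, int m, int n, int k) {
  for (int i = 0; i < m; ++i) {      // loop over the LHS non-reduction dim
    for (int j = 0; j < n; ++j) {    // loop over the RHS non-reduction dim
      float accum = 0.0f;            // plays the role of the alloca'd accumulator
      for (int r = 0; r < k; ++r) {  // the reduction loop
        accum += lhs[i * k + r] * rhs[r * n + j];
      }
      (*out)[i * n + j] = accum;     // store once the reduction loop exits
    }
  }
}
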
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
new file mode 100644
index 0000000000..44dfe5f2a9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace cpu {
+
+// Helper class for emitting LLVM IR to perform the dot operation.
+class DotOpEmitter {
+ public:
+ // Emit LLVM IR to perform the dot operation on lhs_array and rhs_array and
+ // place the result in target_array. IR is emitted at current insert point of
+ // the builder. Upon completion of the method, the insert point is set to the
+ // end of all instructions emitted for this operation.
+ static tensorflow::Status EmitDotOperation(
+ const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs,
+ const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array,
+ const llvm_ir::IrArray& rhs_array,
+ llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder);
+
+ private:
+ DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
+ bool transpose_rhs, const llvm_ir::IrArray& target_array,
+ const llvm_ir::IrArray& lhs_array,
+ const llvm_ir::IrArray& rhs_array,
+ llvm::Value* executable_run_options_value,
+ llvm::IRBuilder<>* ir_builder);
+
+ // Emits the IR to perform the dot operation.
+ tensorflow::Status Emit();
+
+ // Emits instructions to perform a scalar dot product (a multiply of the
+ // LHS and RHS) and store the results in the target.
+ tensorflow::Status EmitScalarDot();
+
+ // Emits a call to the CPU runtime to perform the matrix multiply.
+ tensorflow::Status EmitCallToRuntime();
+
+ // Emits a series of nested loops for iterating over an operand array in the
+ // dot operation. Loops are constructed in major to minor dimension layout
+ // order. No loop is emitted for the given reduction_dimension. The function
+ // returns an IrArray index for the given operand_array containing the indvars
+ // of the loops. All dimensions of the index are filled except for the
+ // reduction dimension. name_suffix is the string to append to the names of
+ // LLVM constructs (eg, basic blocks) constructed by this method.
+ llvm_ir::IrArray::Index EmitOperandArrayLoopNest(
+ llvm_ir::ForLoopNest* loop_nest, const llvm_ir::IrArray& operand_array,
+ int64 reduction_dimension, tensorflow::StringPiece name_suffix);
+
+ // Our runtime operation requires that all arrays have the same layout,
+ // no padding, and a rank of two.
+ bool ShapesAreLegalForRuntimeDot() const;
+
+ const HloInstruction& dot_;
+ const bool transpose_lhs_;
+ const bool transpose_rhs_;
+ const llvm_ir::IrArray& target_array_;
+ const llvm_ir::IrArray& lhs_array_;
+ const llvm_ir::IrArray& rhs_array_;
+ llvm::Value* executable_run_options_value_;
+ llvm::IRBuilder<>* ir_builder_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_DOT_OP_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
new file mode 100644
index 0000000000..9b46c35b41
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
+
+#include <string>
+
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace cpu {
+
+StatusOr<llvm::Value*> CpuElementalIrEmitter::EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const {
+ switch (op->opcode()) {
+ case HloOpcode::kTanh: {
+ PrimitiveType element_type = op->shape().element_type();
+ string function_name;
+ switch (element_type) {
+ case F32:
+ function_name = "tanhf";
+ break;
+ case F64:
+ function_name = "tanh";
+ break;
+ default:
+ return Unimplemented("tanh");
+ }
+ // Create function type for the function.
+ llvm::FunctionType* function_type = llvm::FunctionType::get(
+ llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
+ llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_),
+ /*isVarArg=*/false);
+ // Create function declaration for 'tanhf'.
+ llvm::Function* function =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ llvm_ir::AsStringRef(function_name), function_type));
+ function->setCallingConv(llvm::CallingConv::C);
+ function->setDoesNotThrow();
+ function->setDoesNotAccessMemory();
+ // Create instruction to call 'tanhf'.
+ return ir_builder_->CreateCall(function, operand_value);
+ }
+ default:
+ return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
+ }
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
new file mode 100644
index 0000000000..5160217674
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace cpu {
+
+class CpuElementalIrEmitter : public ElementalIrEmitter {
+ public:
+ CpuElementalIrEmitter(const HloModuleConfig& module_config,
+ llvm::IRBuilder<>* ir_builder, llvm::Module* module)
+ : ElementalIrEmitter(module_config, module, ir_builder) {}
+
+ protected:
+ StatusOr<llvm::Value*> EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const override;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ELEMENTAL_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc
new file mode 100644
index 0000000000..23a2dfcc32
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.cc
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+InfeedBuffer::~InfeedBuffer() = default;
+
+InfeedManager::InfeedManager() : current_buffer_(nullptr) {}
+
+void InfeedManager::Reset() {
+ std::unique_lock<std::mutex> l(mu_);
+ CHECK(!current_buffer_);
+ for (auto buffer : enqueued_buffer_) {
+ buffer->Done();
+ }
+ enqueued_buffer_.clear();
+}
+
+void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) {
+ std::unique_lock<std::mutex> l(mu_);
+ bool was_empty = enqueued_buffer_.empty();
+ enqueued_buffer_.push_back(buffer);
+ if (was_empty) {
+ // This has the potential to suffer from the notified thread
+ // immediately trying and failing to acquire mu_, but seems
+ // preferable to the alternative of notifying outside the lock
+ // on every enqueue.
+ cv_.notify_one();
+ }
+}
+
+InfeedBuffer* InfeedManager::BlockingDequeueBuffer() {
+ std::unique_lock<std::mutex> l(mu_);
+ while (enqueued_buffer_.empty()) {
+ cv_.wait(l);
+ }
+ CHECK(!current_buffer_);
+ current_buffer_ = enqueued_buffer_.front();
+ enqueued_buffer_.pop_front();
+ return current_buffer_;
+}
+
+void InfeedManager::ReleaseCurrentBuffer(int32 length, void* data) {
+ std::unique_lock<std::mutex> l(mu_);
+ CHECK(current_buffer_);
+ CHECK_EQ(length, current_buffer_->length());
+ CHECK_EQ(data, current_buffer_->data());
+ current_buffer_->Done();
+ current_buffer_ = nullptr;
+}
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager.h b/tensorflow/compiler/xla/service/cpu/infeed_manager.h
new file mode 100644
index 0000000000..298729f31f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/infeed_manager.h
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header declares the abstract class for the infeed manager that
+// is used by the CPU runtime to transfer buffers into an executing
+// CPU computation, e.g., to feed data into a while loop.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_
+
+// TODO(misard) Adding NOLINT because as soon as XLA is
+// open-sourced this will use the tensorflow wrapper classes.
+#include <condition_variable> // NOLINT(build/c++11)
+#include <deque>
+#include <mutex> // NOLINT(build/c++11)
+
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace cpu {
+namespace runtime {
+
+// Abstract class defining an infeed buffer that is passed to the
+// runtime by a client. The client manages the storage of the buffer.
+class InfeedBuffer {
+ public:
+ virtual ~InfeedBuffer();
+
+ virtual int32 length() = 0;
+ virtual void* data() = 0;
+ virtual void Done() = 0;
+};
+
+// Client-side class used to enqueue infeed buffers.
+class InfeedManager {
+ public:
+ InfeedManager();
+
+ // Calls the completion callback for any enqueued buffers that have
+ // not been dequeued by the runtime, and empties the infeed
+ // queue. Reset may not be called while a runtime computation is
+ // processing a dequeued buffer. The only safe way to ensure this
+ // condition is to call Reset when no computation is taking place.
+ void Reset();
+
+  // Adds buffer to the infeed queue. buffer->Done will be called when
+  // the buffer is no longer accessed by the InfeedManager, either as a
+  // result of a call to Reset or because the runtime has dequeued and
+  // used the buffer.
+ void EnqueueBuffer(InfeedBuffer* buffer);
+
+ // Blocks until the infeed queue is non-empty, then returns the
+ // buffer at the head of the queue. Sets the current buffer to be
+ // the returned buffer. It is an error to call BlockingDequeueBuffer
+ // if there is an unreleased current buffer, i.e.,
+ // ReleaseCurrentBuffer must be called between calls to
+ // BlockingDequeueBuffer.
+ InfeedBuffer* BlockingDequeueBuffer();
+
+ // Releases the current buffer, which is the last buffer returned by
+  // BlockingDequeueBuffer and not yet released. length and data must
+ // match the buffer->length() and buffer->data() for the current
+ // buffer.
+ void ReleaseCurrentBuffer(int32 length, void* data);
+
+ private:
+ std::mutex mu_;
+ // Condition variable that is signaled every time a buffer is
+ // enqueued to an empty queue.
+ std::condition_variable cv_;
+ // InfeedBuffer* queue contents are not owned, but buffer->Done must
+ // be called when the buffer is no longer needed by the runtime.
+ std::deque<InfeedBuffer*> enqueued_buffer_;
+ // If non-NULL, the buffer that is currently being processed by the
+ // runtime. Not owned.
+ InfeedBuffer* current_buffer_;
+};
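+
+// A minimal usage sketch of the protocol described above. `MyBuffer` stands
+// for a hypothetical client-provided InfeedBuffer subclass; GetInfeedManager
+// is the CPU runtime's accessor for the manager instance (see its use in the
+// infeed manager test).
+//
+//   InfeedManager* infeed = GetInfeedManager();
+//
+//   // Client thread: hand a buffer to the runtime.
+//   infeed->EnqueueBuffer(new MyBuffer(/*length=*/64));
+//
+//   // Runtime thread: dequeue the head of the queue, consume its contents,
+//   // then release it, which invokes MyBuffer::Done.
+//   InfeedBuffer* b = infeed->BlockingDequeueBuffer();
+//   // ... copy b->length() bytes out of b->data() ...
+//   infeed->ReleaseCurrentBuffer(b->length(), b->data());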
+
+} // namespace runtime
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_INFEED_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc
new file mode 100644
index 0000000000..c65d821660
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/infeed_manager_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class InfeedManagerTest : public ::testing::Test {};
+
+class TestInfeedBuffer : public cpu::runtime::InfeedBuffer {
+ public:
+ explicit TestInfeedBuffer(int32 length)
+ : done_called_(false), length_(length) {}
+ ~TestInfeedBuffer() override { EXPECT_TRUE(done_called_); }
+
+ int32 length() override { return length_; }
+ void* data() override { return nullptr; }
+ void Done() override {
+ CHECK(!done_called_);
+ done_called_ = true;
+ }
+
+ private:
+ bool done_called_;
+ int32 length_;
+};
+
+void ProcessNextBuffer(int32 length) {
+ void* buffer = __xla_cpu_runtime_AcquireInfeedBufferForDequeue(length);
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue(length, buffer);
+}
+
+TEST_F(InfeedManagerTest, SingleThreadedSequential) {
+ TestInfeedBuffer* a = new TestInfeedBuffer(64);
+ TestInfeedBuffer* b = new TestInfeedBuffer(32);
+
+ cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager();
+
+ infeed->EnqueueBuffer(a);
+ infeed->EnqueueBuffer(b);
+ ProcessNextBuffer(a->length());
+ ProcessNextBuffer(b->length());
+}
+
+TEST_F(InfeedManagerTest, SingleThreadedInterleaved) {
+ TestInfeedBuffer* a = new TestInfeedBuffer(64);
+ TestInfeedBuffer* b = new TestInfeedBuffer(32);
+
+ cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager();
+
+ infeed->EnqueueBuffer(a);
+ ProcessNextBuffer(a->length());
+ infeed->EnqueueBuffer(b);
+ ProcessNextBuffer(b->length());
+}
+
+TEST_F(InfeedManagerTest, MultiThreaded) {
+ tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test", 2);
+
+ cpu::runtime::InfeedManager* infeed = cpu::runtime::GetInfeedManager();
+
+ const int32 length = 64;
+
+ pool.Schedule([infeed]() {
+ // Spin for 100 milliseconds
+ int64 start_micros = tensorflow::Env::Default()->NowMicros();
+ while (true) {
+ int64 end_micros = tensorflow::Env::Default()->NowMicros();
+ if ((end_micros - start_micros) >= 100000) { // 100 ms
+ break;
+ }
+ }
+ TestInfeedBuffer* a = new TestInfeedBuffer(length);
+ infeed->EnqueueBuffer(a);
+ });
+
+ ProcessNextBuffer(length);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
new file mode 100644
index 0000000000..2d855d0eb1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+
+namespace xla {
+namespace cpu {
+
+bool PotentiallyImplementedAsEigenConvolution(
+ const HloInstruction& convolution) {
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ if (!flags->xla_cpu_use_eigen) {
+ return false;
+ }
+
+ // The following conditions are necessary (but not sufficient) for
+ // implementing `convolution` with Eigen convolution:
+ // - the input and kernel have a non-zero number of elements.
+ // - the input is in NHWC or NWHC order.
+ // - the kernel is in HWIO or WHIO order.
+ // - the spatial dimensions are in the same relative order in the input,
+ // kernel and output.
+ //
+ // To be sufficient, certain layout constraints need to be satisfied as well.
+ if (ShapeUtil::HasZeroElements(convolution.operand(0)->shape()) ||
+ ShapeUtil::HasZeroElements(convolution.operand(1)->shape())) {
+ return false;
+ }
+ const ConvolutionDimensionNumbers& dnums =
+ convolution.convolution_dimension_numbers();
+ // Only 2D convolutions are supported at the moment.
+ // TODO(b/32897908): add an optimized implementation for 3D convolution.
+ if (dnums.spatial_dimensions_size() != 2) {
+ return false;
+ }
+ bool input_spatial_dims_ascending =
+ dnums.spatial_dimensions(0) < dnums.spatial_dimensions(1);
+ bool kernel_spatial_dims_ascending =
+ dnums.kernel_spatial_dimensions(0) < dnums.kernel_spatial_dimensions(1);
+ return dnums.batch_dimension() == 0 && dnums.feature_dimension() == 3 &&
+ input_spatial_dims_ascending == kernel_spatial_dims_ascending &&
+ dnums.kernel_input_feature_dimension() == 2 &&
+ dnums.kernel_output_feature_dimension() == 3;
+}
+
+namespace {
+
+// Return whether the given shape is a matrix with no padding.
+bool IsRank2WithNoPadding(const Shape& shape) {
+ return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
+}
+
+// In a gemm operation where output = lhs * rhs, check whether the given shapes
+// are valid for the operation.
+bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
+ const Shape& output_shape) {
+ // The inputs and the output must
+ // 1) be matrices with no padding, and
+ // 2) have an allowed element type.
+ return output_shape.element_type() == F32 &&
+ IsRank2WithNoPadding(lhs_shape) && IsRank2WithNoPadding(rhs_shape) &&
+ IsRank2WithNoPadding(output_shape);
+}
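+
+// For example, under these rules an F32 dot of a [64, 32] lhs with a [32, 16]
+// rhs producing a [64, 16] output is a valid gemm candidate, as long as none
+// of the shapes carries a padded layout; an output element type other than
+// F32, or an operand of rank other than 2, disqualifies the operation.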
+} // namespace
+
+bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ if (!flags->xla_cpu_use_eigen) {
+ return false;
+ }
+
+  // For certain types of Dot, we can call Eigen.
+ if (hlo.opcode() == HloOpcode::kDot) {
+ const Shape& lhs_shape = hlo.operand(0)->shape();
+ const Shape& rhs_shape = hlo.operand(1)->shape();
+
+ if (ShapeUtil::HasZeroElements(lhs_shape) ||
+ ShapeUtil::HasZeroElements(rhs_shape)) {
+ return false;
+ }
+
+ // If gemm can accept the operand shapes, use it rather than a custom
+ // kernel.
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
+      // The reduction dimensions of the two operands should have matching
+      // sizes. Shape inference guarantees this invariant, so the check here
+      // is for programming errors.
+ CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
+ return true;
+ }
+ }
+
+ if (hlo.opcode() == HloOpcode::kFusion &&
+ hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
+ hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
+ const Shape& lhs_shape = hlo.operand(0)->shape();
+ const Shape& rhs_shape = hlo.operand(1)->shape();
+ if (ShapeUtil::HasZeroElements(lhs_shape) ||
+ ShapeUtil::HasZeroElements(rhs_shape)) {
+ return false;
+ }
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
new file mode 100644
index 0000000000..d48646d116
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h
@@ -0,0 +1,32 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+namespace cpu {
+
+bool PotentiallyImplementedAsEigenConvolution(
+ const HloInstruction& convolution);
+
+bool PotentiallyImplementedAsEigenDot(const HloInstruction& dot);
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
new file mode 100644
index 0000000000..7a839169fc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -0,0 +1,1774 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <algorithm>
+#include <iterator>
+#include <limits>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Constants.h"
+#include "external/llvm/include/llvm/IR/GlobalVariable.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Intrinsics.h"
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
+#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+using llvm_ir::SetToFirstInsertPoint;
+
+namespace cpu {
+
+IrEmitter::IrEmitter(
+ const HloModule& hlo_module, const HloModuleConfig& hlo_module_config,
+ const BufferAssignment& assignment, llvm::Module* llvm_module,
+ const std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
+ : assignment_(assignment),
+ module_(llvm_module),
+ arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
+ ir_builder_(llvm_module->getContext()),
+ hlo_to_profile_idx_(hlo_to_profile_idx),
+ alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
+ hlo_module_config_(hlo_module_config) {
+ llvm::FastMathFlags fast_math_flags;
+ llvm_ir::SetFastMathFlags(&fast_math_flags);
+ ir_builder_.setFastMathFlags(fast_math_flags);
+}
+
+StatusOr<llvm::Function*> IrEmitter::EmitComputation(
+ HloComputation* computation, const string& function_name_prefix,
+ bool is_entry_computation,
+ std::vector<const HloInstruction*>* instruction_order) {
+ string function_name = name_uniquer_.GetUniqueName(function_name_prefix);
+ VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]";
+ InitializeIrFunction(function_name, is_entry_computation);
+  // The rdtscp instruction is x86 specific. We will fall back to LLVM's generic
+ // readcyclecounter if it is unavailable.
+ bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 ||
+ arch_type_ == llvm::Triple::ArchType::x86_64;
+ profiling_state_ = ProfilingState(is_entry_computation, use_rdtscp,
+ GetProfileCountersArgument());
+ if (instruction_order != nullptr) {
+ TF_RETURN_IF_ERROR(computation->root_instruction()->AcceptOrdered(
+ this, *instruction_order));
+ } else {
+ TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
+ }
+ InsertOrDie(&emitted_functions_, computation, compute_function_);
+
+ return compute_function_;
+}
+
+static llvm::Argument* GetArg(llvm::Function* f, int idx) {
+ llvm::Function::arg_iterator arg_iter = f->arg_begin();
+ std::advance(arg_iter, idx);
+ return &*arg_iter;
+}
+
+void IrEmitter::InitializeIrFunction(const string& function_name,
+ bool is_entry_computation) {
+ // The function signature is:
+ // void function(i8* retval, i8* run_options, i8** params, i8** temps,
+ // i64* prof_counters)
+ //
+ // retval: points to the returned value.
+ // params: address of an array with pointers to parameters.
+ // temps: address of an array with pointers to temporary buffers.
+ //
+ // Therefore, the generated function's signature (FunctionType) is statically
+ // determined - parameter unpacking is done in code generated into the
+ // function, rather than by a prologue dictated by the platform ABI.
+ //
+ // /--------------\
+ // retval ----------> | return value |
+ // \--------------/
+ //
+ // /-------------------------------\
+ // run_options -----> | xla::ExecutableRunOptions |
+ // \-------------------------------/
+ //
+ // /---------------------------------------------\
+ // params --------> | param 0 | param 1 | ..... | param N-1 |
+ // | addr | addr | | addr |
+ // \---------------------------------------------/
+ // | | |
+ // | | |
+ // V V V
+ // /---------\ /---------\ /-----------\
+ // | param 0 | | param 1 | | param N-1 |
+ // \---------/ \---------/ \-----------/
+ //
+ // /---------------------------------------------\
+ // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
+ // | addr | addr | | addr |
+ // \---------------------------------------------/
+ // | | |
+ // | | |
+ // V V V
+ // /---------\ /---------\ /-----------\
+ // | temp 0 | | temp 1 | | temp N-1 |
+ // \---------/ \---------/ \-----------/
+ //
+ // /---------------------------------------------\
+ // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
+ // (elided for aot) \---------------------------------------------/
+
+ // Even though the type of params and temps is void** in the host's view, in
+ // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
+ // to use GEPs to unravel the indirection layers.
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+ llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
+ llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext());
+ std::vector<llvm::Type*> compute_function_params(
+ {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
+ if (hlo_to_profile_idx_) {
+ compute_function_params.push_back(i64_ptr_type);
+ }
+ llvm::FunctionType* compute_function_type = llvm::FunctionType::get(
+ /*Result=*/llvm::Type::getVoidTy(module_->getContext()),
+ /*Params=*/compute_function_params,
+ /*isVarArg=*/false);
+
+ // Functions with local linkage get an inlining bonus. Because we know
+  // a priori that embedded functions (non-entry functions) will not have their
+  // names resolved, give them local linkage.
+ llvm::Function::LinkageTypes linkage =
+ is_entry_computation ? llvm::GlobalValue::ExternalLinkage
+ : llvm::GlobalValue::InternalLinkage;
+ compute_function_ = llvm::Function::Create(/*Ty=*/compute_function_type,
+ /*Linkage=*/linkage,
+ /*Name=*/function_name.c_str(),
+ /*Module=*/module_);
+ compute_function_->setCallingConv(llvm::CallingConv::C);
+
+ // Set meaningful names for the function's arguments: useful for debugging.
+ llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin();
+ arg_iter->setName("retval");
+ (++arg_iter)->setName("run_options");
+ (++arg_iter)->setName("params");
+ (++arg_iter)->setName("temps");
+ if (hlo_to_profile_idx_) {
+ (++arg_iter)->setName("prof_counters");
+ }
+
+  // We know a priori that the function arguments are guaranteed to point to
+ // disjoint objects.
+ llvm::Argument* retval = GetResultArgument();
+ for (llvm::Argument& argument : compute_function_->args()) {
+ // However, the return buffer aliases the temporaries and thus cannot be
+ // marked noalias.
+ if (&argument == retval) {
+ continue;
+ }
+ compute_function_->setDoesNotAlias(argument.getArgNo() + 1);
+ }
+
+ ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
+ /*Context=*/module_->getContext(),
+ /*Name=*/"entry",
+ /*Parent=*/compute_function_));
+}
+
+IrEmitter::~IrEmitter() {}
+
+Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
+ VLOG(2) << "HandleBitcast: " << bitcast->ToString();
+ emitted_value_[bitcast] = ir_builder_.CreateBitCast(
+ GetEmittedValueFor(bitcast->operand(0)),
+ IrShapeType(bitcast->shape())->getPointerTo(), bitcast->name().c_str());
+ return Status::OK();
+}
+
+Status IrEmitter::HandleConstant(HloInstruction* constant,
+ const Literal& literal) {
+ VLOG(2) << "HandleConstant: " << constant->ToString();
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
+ llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
+ /*Module=*/*module_,
+ /*Type=*/initializer->getType(),
+ /*isConstant=*/true,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/initializer,
+ /*Name=*/"");
+ emitted_value_[constant] = global_for_const;
+ VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const);
+ VLOG(2) << " its type: "
+ << llvm_ir::DumpToString(*global_for_const->getType());
+ return Status::OK();
+}
+
+Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) {
+ if (ShapeUtil::IsTuple(copy->shape())) {
+ // kCopy shallow copies a tuple so just memcpy the top-level buffer.
+ TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy));
+ emitted_value_[copy] = copy_value;
+ return EmitMemcpy(*operand, *copy);
+ } else {
+ // Use the elemental emitter for non-tuple shapes.
+ return DefaultAction(copy);
+ }
+}
+
+// Calculate the alignment of a buffer with a particular size.
+int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) {
+ // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on
+ // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for
+ // allocations smaller than 16 bytes and at least alignment 16 for allocations
+ // greater than or equal to 16 bytes. N.B. We could improve on this lower
+ // bound by explicitly allocating the memory with posix_memalign. This is
+ // complicated by our desire to allow parameter buffers created by clients to
+ // be consumed directly by the JIT.
+ if (buffer_size == 0) {
+ // No need to align empty buffers.
+ return 1;
+ }
+ int pointer_size = module_->getDataLayout().getPointerSize();
+ int buffer_alignment = buffer_size >= 16 ? 2 * pointer_size : 8;
+ DCHECK_GT(buffer_alignment, 0);
+
+ return buffer_alignment;
+}
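+
+// For example, on a typical 64-bit target (pointer size 8) the rule above
+// yields alignment 1 for an empty buffer, alignment 8 for an 8-byte buffer,
+// and alignment 16 for any buffer of 16 bytes or more.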
+
+// Calculate the alignment of a buffer allocated for a given primitive type.
+int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
+ int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
+ DCHECK_GE(buffer_size, 0);
+ DCHECK_LE(buffer_size, SIZE_MAX);
+
+ return MinimumAlignmentForBufferSize(buffer_size);
+}
+
+int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
+ return llvm_ir::ByteSizeOf(shape, module_->getDataLayout());
+}
+
+// Calculate the alignment of a buffer allocated for a given shape.
+int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
+ int64 buffer_size = ByteSizeOf(shape);
+ DCHECK_GE(buffer_size, 0);
+ DCHECK_LE(buffer_size, SIZE_MAX);
+
+ return MinimumAlignmentForBufferSize(buffer_size);
+}
+
+void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
+ const Shape& shape) {
+ int alignment = MinimumAlignmentForShape(shape);
+ if (alignment > 1) {
+ llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
+ }
+}
+
+void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
+ int64 buffer_size) {
+ int alignment = MinimumAlignmentForBufferSize(buffer_size);
+ if (alignment > 1) {
+ llvm_ir::SetAlignmentMetadataForLoad(load, alignment);
+ }
+}
+
+void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ const Shape& shape) {
+ AttachDereferenceableMetadataForLoad(load, ByteSizeOf(shape));
+}
+
+void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ int64 buffer_size) {
+ if (buffer_size > 0) {
+ llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
+ }
+}
+
+Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) {
+ // A tuple is an array of pointers, one for each operand. Each pointer points
+ // to the output buffer of its corresponding operand. A GetTupleElement
+ // instruction forwards a pointer to the tuple element buffer at the given
+ // index.
+ const Shape& shape = get_tuple_element->shape();
+ emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
+ shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape),
+ GetEmittedValueFor(operand), &ir_builder_);
+ return Status::OK();
+}
+
+Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+
+ if (ShapeUtil::IsTuple(select->shape())) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(select));
+ llvm_ir::EmitTupleSelect(llvm_ir::IrArray(output_address, select->shape()),
+ GetIrArrayForOp(pred), GetEmittedValueFor(on_true),
+ GetEmittedValueFor(on_false), &ir_builder_);
+ emitted_value_[select] = output_address;
+ return Status::OK();
+ }
+
+ return DefaultAction(select);
+}
+
+Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
+ VLOG(2) << "HandleInfeed: " << infeed->ToString();
+
+ // The signature of the acquire infeed buffer function is:
+ //
+ // (void*)(int32 length);
+ llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
+ llvm::Type* int32_type = ir_builder_.getInt32Ty();
+ llvm::FunctionType* acquire_type =
+ llvm::FunctionType::get(i8_ptr_type, {int32_type},
+ /*isVarArg=*/false);
+
+ llvm::Function* acquire_func =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type));
+ acquire_func->setCallingConv(llvm::CallingConv::C);
+
+ // The signature of the release infeed buffer function is:
+ //
+ // (void)(int32 length, void* buffer);
+ llvm::FunctionType* release_type = llvm::FunctionType::get(
+ ir_builder_.getVoidTy(), {int32_type, i8_ptr_type},
+ /*isVarArg=*/false);
+
+ llvm::Function* release_func =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type));
+ release_func->setCallingConv(llvm::CallingConv::C);
+
+ const Shape& shape = infeed->shape();
+ int64 length = ByteSizeOf(shape);
+ if (length > std::numeric_limits<int32>::max()) {
+ return InvalidArgument("infeed buffer length %lld is too large", length);
+ }
+ int32 length_32 = static_cast<int32>(length);
+
+ llvm::Value* acquired_pointer =
+ ir_builder_.CreateCall(acquire_func, {ir_builder_.getInt32(length_32)});
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(infeed));
+
+ ir_builder_.CreateMemCpy(target_address, acquired_pointer, length_32, 1);
+
+ ir_builder_.CreateCall(release_func,
+ {ir_builder_.getInt32(length_32), acquired_pointer});
+
+ emitted_value_[infeed] = target_address;
+
+ return Status::OK();
+}
+
+Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) {
+ // TODO(b/26783907): Implement sort on CPU.
+ return Unimplemented("sort");
+}
+
+Status IrEmitter::HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(tuple));
+ std::vector<llvm::Value*> base_ptrs;
+ for (auto operand : operands) {
+ base_ptrs.push_back(GetEmittedValueFor(operand));
+ }
+ llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()),
+ base_ptrs, &ir_builder_);
+ emitted_value_[tuple] = target_address;
+ return Status::OK();
+}
+
+Status IrEmitter::HandleMap(
+ HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
+ // The called computation should have been emitted previously.
+ llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function);
+
+ return EmitTargetElementLoop(map, [this, map, operands, mapped_ir_function](
+ const llvm_ir::IrArray::Index& index) {
+ std::vector<llvm::Value*> parameter_addresses;
+ for (const HloInstruction* operand : operands) {
+ const llvm_ir::IrArray& array = GetIrArrayForOp(operand);
+ parameter_addresses.push_back(
+ array.EmitArrayElementAddress(index, &ir_builder_));
+ }
+ return EmitElementFunctionCall(mapped_ir_function, map->shape(),
+ parameter_addresses, "map_function");
+ });
+}
+
+Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* operand,
+ const Window& window,
+ HloComputation* function) {
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*reduce_window, /*operands=*/{operand},
+ /*supported_types=*/{F32}));
+
+ // TODO(b/31410564): Implement dilation for reduce-window.
+ if (window_util::HasDilation(window)) {
+ return Unimplemented(
+ "Dilation for reduce-window not implemented on CPU. See b/31410564.");
+ }
+
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+
+ // Pseudo code for reduce window:
+ //
+ // for (coordinates O in the output)
+ // value = init_value;
+ // for (coordinates W in the window)
+ // for each index i:
+ // input coordinates I_i = O_i * stride_i + W_i - pad_low_i
+ // if I within bounds of input:
+ // value = function(value, input(I));
+ // output(O) = value;
+ //
+ // This is completely un-optimized and just here to have something
+ // that works.
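+  //
+  // For example, a 1D add-reduction over the input [1, 2, 3, 4] with window
+  // size 2, stride 1, no padding and init_value 0 folds two input elements
+  // into the accumulator per output coordinate and produces [3, 5, 7].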
+ return EmitTargetElementLoop(
+ reduce_window, [this, reduce_window, operand, window,
+ reducer_function](const llvm_ir::IrArray::Index& index) {
+ // We fold inputs into the accumulator and initialize it to
+ // the initial value on the reduce_window.
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
+ "reduce_window_accumulator_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(operand_element_type));
+ ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
+ reduce_window->operand(1))),
+ accumulator_address);
+
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ llvm_ir::IrArray::Index input_index(index.size());
+ llvm::Value* in_bounds_condition = nullptr;
+ for (int64 i = 0; i < index.size(); ++i) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ input_index[i] = ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ ir_builder_.getInt64(window.dimensions(i).padding_low()));
+
+            // We need to check whether 0 <= input_index[i] < bound; if
+            // not, we are in the padding and can skip the computation.
+            // That is equivalent to input_index[i] < bound as an
+            // *unsigned* comparison, since a negative value will wrap to
+            // a large positive value.
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
+ operand->shape(), i)));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = index_condition;
+ } else {
+ in_bounds_condition =
+ ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ }
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ llvm_ir::IrArray input_array(GetIrArrayForOp(operand));
+ llvm::Value* input_value_address =
+ input_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce_window->shape(),
+ {accumulator_address, input_value_address}, "reducer_function");
+ ir_builder_.CreateStore(result, accumulator_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_address);
+ });
+}
+
+Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
+ CHECK_EQ(select_and_scatter->operand_count(), 3);
+ const auto operand = select_and_scatter->operand(0);
+ const auto source = select_and_scatter->operand(1);
+ const auto init_value = select_and_scatter->operand(2);
+ const Window& window = select_and_scatter->window();
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ const int64 rank = ShapeUtil::Rank(operand->shape());
+ CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
+ CHECK_EQ(rank, window.dimensions_size());
+
+ // TODO(b/31410564): Implement dilation for select-and-scatter.
+ if (window_util::HasDilation(window)) {
+ return Unimplemented(
+ "Dilation for select-and-scatter not implemented on CPU. "
+ "See b/31410564.");
+ }
+
+ // The select and scatter computations should have been emitted previously.
+ llvm::Function* select_function =
+ FindOrDie(emitted_functions_, select_and_scatter->select());
+ llvm::Function* scatter_function =
+ FindOrDie(emitted_functions_, select_and_scatter->scatter());
+
+ // Pseudo code for select-and-scatter:
+ //
+ // initialized_flag is initially off for every window, and is turned on after
+ // the first iteration is completed and the first operand value is selected.
+ //
+ // output(*) = init_value
+ // for (coordinates S in the source) {
+ // initialized_flag = false
+ // for (coordinates W in the window) {
+ // I = S * stride + W - pad_low
+ // if I within bounds of operand:
+ // if !initialized_flag or select(selected_value, operand(I)) == false:
+ // selected_value = operand(I)
+ // selected_index = I
+ // initialized_flag = true
+ // }
+ // output(selected_index) = scatter(output(selected_index), source(S))
+ // }
+ //
+
+ // Initialize the output array with the given init_value.
+ TF_RETURN_IF_ERROR(EmitTargetElementLoop(
+ select_and_scatter,
+ [this, init_value](const llvm_ir::IrArray::Index& target_index) {
+ llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
+ return ir_builder_.CreateLoad(init_value_addr);
+ }));
+
+ // Create a loop to iterate over the source array to scatter to the output.
+ llvm_ir::ForLoopNest source_loops(&ir_builder_);
+ const llvm_ir::IrArray::Index source_index =
+ source_loops.AddLoopsForShape(source->shape(), "source");
+ SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(),
+ &ir_builder_);
+
+ // Allocate space to keep the currently selected value, its index, and
+ // the boolean initialized_flag, which is initially set to false.
+ llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
+ "selected_value_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(operand_element_type));
+ llvm::Value* selected_index_address =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
+ "selected_index_address", &ir_builder_);
+ llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
+ ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address);
+
+ // Create the inner loop to iterate over the window.
+ llvm_ir::ForLoopNest window_loops(&ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
+ &ir_builder_);
+
+  // Compute the operand index to visit and evaluate whether the operand
+  // index is within the bounds. The unsigned comparison includes
+ // checking whether the operand index >= 0.
+ llvm_ir::IrArray::Index operand_index(source_index.size());
+ llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ ir_builder_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ operand_index[i],
+ ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition =
+ ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ // Only need to do something if the operand index is within the bounds. First
+ // check if the initialized_flag is set.
+ llvm_ir::LlvmIfData if_in_bounds =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
+ llvm_ir::LlvmIfData if_initialized =
+ llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address),
+ "initialized", &ir_builder_);
+
+ // If the initialized_flag is false, initialize the selected value and index
+ // with the currently visiting operand.
+ SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
+ const auto save_operand_index = [&](
+ const llvm_ir::IrArray::Index& operand_index) {
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
+ selected_index_address, {ir_builder_.getInt32(i)});
+ ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
+ }
+ };
+ llvm_ir::IrArray operand_array(GetIrArrayForOp(operand));
+ llvm::Value* operand_data =
+ operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
+ ir_builder_.CreateStore(operand_data, selected_value_address);
+ save_operand_index(operand_index);
+ ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address);
+
+ // If the initialized_flag is true, call the `select` function to potentially
+ // update the selected value and index with the currently visiting operand.
+ SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
+ const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
+ llvm::Value* operand_address =
+ operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ select_function, output_shape, {selected_value_address, operand_address},
+ "select_function");
+
+ // If the 'select' function returns false, update the selected value and the
+ // index to the currently visiting operand.
+ llvm::Value* cond = ir_builder_.CreateICmpNE(
+ result, llvm::ConstantInt::get(
+ llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
+ "boolean_predicate");
+ llvm_ir::LlvmIfData if_select_lhs =
+ llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
+ SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
+ ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
+ selected_value_address);
+ save_operand_index(operand_index);
+
+ // After iterating over the window elements, scatter the source element to
+ // the selected index of the output. The value we store at the output
+ // location is computed by calling the `scatter` function with the source
+ // value and the current output value.
+ SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
+ &ir_builder_);
+ llvm_ir::IrArray::Index selected_index;
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
+ selected_index_address, {ir_builder_.getInt32(i)});
+ selected_index.push_back(
+ ir_builder_.CreateLoad(selected_index_address_slot));
+ }
+ llvm_ir::IrArray source_array(GetIrArrayForOp(source));
+ llvm::Value* source_value_address =
+ source_array.EmitArrayElementAddress(source_index, &ir_builder_);
+ llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter));
+ llvm::Value* output_value_address =
+ output_array.EmitArrayElementAddress(selected_index, &ir_builder_);
+ llvm::Value* scatter_value = EmitElementFunctionCall(
+ scatter_function, source->shape(),
+ {output_value_address, source_value_address}, "scatter_function");
+ output_array.EmitWriteArrayElement(selected_index, scatter_value,
+ &ir_builder_);
+
+ SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(),
+ &ir_builder_);
+ return Status::OK();
+}
+
+Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*dot, /*operands=*/{lhs, rhs},
+ /*supported_types=*/{F32, F64}));
+
+ llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs));
+ llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs));
+
+ Shape target_shape = dot->shape();
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(dot));
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*dot, &target_array);
+
+ VLOG(2) << "HandleDot: ";
+ VLOG(2) << " lhs operand: "
+ << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
+ VLOG(2) << " rhs operand: "
+ << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
+ VLOG(2) << " target: "
+ << llvm_ir::DumpToString(*target_array.GetBasePointer());
+
+ // Dot operation is complicated so we delegate to a helper class.
+ TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
+ *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
+ lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_));
+
+ emitted_value_[dot] = target_address;
+ return Status::OK();
+}
+
+Status IrEmitter::HandleConvolution(HloInstruction* convolution,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window) {
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*convolution, /*operands=*/{lhs, rhs},
+ /*supported_types=*/{F32}));
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
+ const Shape& lhs_shape = lhs->shape();
+ const Shape& rhs_shape = rhs->shape();
+ const Shape& convolution_shape = convolution->shape();
+ // The input, kernel and output agree with respect to layout.
+ if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
+ LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
+ LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
+ llvm::Value* lhs_address = GetEmittedValueFor(lhs);
+ llvm::Value* rhs_address = GetEmittedValueFor(rhs);
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(convolution));
+
+ const ConvolutionDimensionNumbers& dnums =
+ convolution->convolution_dimension_numbers();
+
+ // Input tensor.
+ const Shape& input_shape = convolution->operand(0)->shape();
+ int64 input_batch = input_shape.dimensions(dnums.batch_dimension());
+ int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0));
+ int64 input_cols = input_shape.dimensions(dnums.spatial_dimensions(1));
+ int64 input_channels = input_shape.dimensions(dnums.feature_dimension());
+
+ // Kernel tensor.
+ const Shape& kernel_shape = convolution->operand(1)->shape();
+ int64 kernel_rows =
+ kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
+ int64 kernel_cols =
+ kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
+ int64 kernel_channels =
+ kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
+ int64 kernel_filters =
+ kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
+
+ // Output tensor.
+ const Shape& convolution_shape = convolution->shape();
+ int64 output_rows =
+ convolution_shape.dimensions(dnums.spatial_dimensions(0));
+ int64 output_cols =
+ convolution_shape.dimensions(dnums.spatial_dimensions(1));
+
+ // Extract the window stride for the convolution.
+ const Window& window = convolution->window();
+ int64 row_stride = window.dimensions(0).stride();
+ int64 col_stride = window.dimensions(1).stride();
+
+ int64 padding_top = window.dimensions(0).padding_low();
+ int64 padding_bottom = window.dimensions(0).padding_high();
+ int64 padding_left = window.dimensions(1).padding_low();
+ int64 padding_right = window.dimensions(1).padding_high();
+
+ int64 lhs_row_dilation = window.dimensions(0).base_dilation();
+ int64 lhs_col_dilation = window.dimensions(1).base_dilation();
+ int64 rhs_row_dilation = window.dimensions(0).window_dilation();
+ int64 rhs_col_dilation = window.dimensions(1).window_dilation();
+
+ // Args have been computed, make the call.
+ llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo();
+ llvm::Type* int64_type = ir_builder_.getInt64Ty();
+ llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
+ llvm::FunctionType* conv_type = llvm::FunctionType::get(
+ ir_builder_.getVoidTy(),
+ {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
+ int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type,
+ int64_type, int64_type, int64_type, int64_type},
+ /*isVarArg=*/false);
+ legacy_flags::CpuRuntimeFlags* flags = legacy_flags::GetCpuRuntimeFlags();
+ const char* fn_name =
+ (flags->xla_cpu_multi_thread_eigen
+ ? runtime::kEigenConvF32SymbolName
+ : runtime::kEigenSingleThreadedConvF32SymbolName);
+ llvm::Function* conv_func = llvm::cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, conv_type));
+ conv_func->setCallingConv(llvm::CallingConv::C);
+ conv_func->setDoesNotThrow();
+ conv_func->setOnlyAccessesArgMemory();
+ ir_builder_.CreateCall(
+ conv_func,
+ {
+ GetExecutableRunOptionsArgument(),
+ ir_builder_.CreateBitCast(target_address, float_ptr_type),
+ ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
+ ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
+ ir_builder_.getInt64(input_batch),
+ ir_builder_.getInt64(input_rows),
+ ir_builder_.getInt64(input_cols),
+ ir_builder_.getInt64(input_channels),
+ ir_builder_.getInt64(kernel_rows),
+ ir_builder_.getInt64(kernel_cols),
+ ir_builder_.getInt64(kernel_channels),
+ ir_builder_.getInt64(kernel_filters),
+ ir_builder_.getInt64(output_rows),
+ ir_builder_.getInt64(output_cols),
+ ir_builder_.getInt64(row_stride),
+ ir_builder_.getInt64(col_stride),
+ ir_builder_.getInt64(padding_top),
+ ir_builder_.getInt64(padding_bottom),
+ ir_builder_.getInt64(padding_left),
+ ir_builder_.getInt64(padding_right),
+ ir_builder_.getInt64(lhs_row_dilation),
+ ir_builder_.getInt64(lhs_col_dilation),
+ ir_builder_.getInt64(rhs_row_dilation),
+ ir_builder_.getInt64(rhs_col_dilation),
+ });
+ emitted_value_[convolution] = target_address;
+
+ return Status::OK();
+ }
+ }
+
+  // This is a completely un-optimized version of convolution just to
+  // have an early version that works; e.g., the input index and
+  // padding calculations are not hoisted out of the inner loop.
+ //
+ // See the description of convolution in the XLA documentation for the pseudo
+ // code for convolution.
+ return EmitTargetElementLoop(
+ convolution, [this, convolution, lhs, rhs, window,
+ dnums](const llvm_ir::IrArray::Index& index) {
+ int num_spatial_dims = dnums.spatial_dimensions_size();
+ std::vector<llvm::Value*> output_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ output_spatial[i] = index[dnums.spatial_dimensions(i)];
+ }
+ llvm::Value* output_feature = index[dnums.feature_dimension()];
+ llvm::Value* batch = index[dnums.batch_dimension()];
+
+ // We will accumulate the products into this sum to calculate
+ // the output entry at the given index.
+ PrimitiveType lhs_element_type = lhs->shape().element_type();
+ llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_),
+ "convolution_sum_address", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(lhs_element_type));
+ ir_builder_.CreateStore(
+ llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), sum_address);
+
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ std::vector<llvm::Value*> kernel_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial[i] =
+ loops
+ .AddLoop(0, rhs->shape().dimensions(
+ dnums.kernel_spatial_dimensions(i)),
+ tensorflow::strings::StrCat("k", i))
+ ->GetIndVarValue();
+ }
+ llvm::Value* input_feature =
+ loops
+ .AddLoop(0, lhs->shape().dimensions(dnums.feature_dimension()),
+ "iz")
+ ->GetIndVarValue();
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Calculate the spatial index in the input array, taking striding,
+ // dilation and padding into account. An index in the padding will be
+ // out of the bounds of the array.
+ const auto calculate_input_index = [this](
+ llvm::Value* output_index, llvm::Value* kernel_index,
+ const WindowDimension& window_dim) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ output_index, ir_builder_.getInt64(window_dim.stride()));
+ llvm::Value* dilated_kernel_index = ir_builder_.CreateNSWMul(
+ kernel_index, ir_builder_.getInt64(window_dim.window_dilation()));
+ return ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, dilated_kernel_index),
+ ir_builder_.getInt64(window_dim.padding_low()));
+ };
+ std::vector<llvm::Value*> input_spatial(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] = calculate_input_index(
+ output_spatial[i], kernel_spatial[i], window.dimensions(i));
+ }
+
+        // We need to check whether 0 <= input dim < bound; if not, we are in
+        // the padding and can skip the computation. That is equivalent to
+        // input dim < bound as an *unsigned* comparison, since a negative
+        // value will wrap to a large positive value. The input dim is dilated,
+        // so we need to dilate the bound as well to match.
+
+ // Also need to check that the input coordinates are not in one of the
+ // holes created by base dilation.
+ const auto not_in_hole = [&](llvm::Value* input_index,
+ int64 base_dilation) {
+ llvm::Value* remainder = ir_builder_.CreateSRem(
+ input_index, ir_builder_.getInt64(base_dilation));
+ return ir_builder_.CreateICmpEQ(remainder, ir_builder_.getInt64(0));
+ };
+
+ llvm::Value* in_bounds_condition = nullptr;
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ llvm::ConstantInt* input_bound =
+ ir_builder_.getInt64(window_util::DilatedBound(
+ lhs->shape().dimensions(dnums.spatial_dimensions(i)),
+ window.dimensions(i).base_dilation()));
+ llvm::Value* dim_in_bound =
+ ir_builder_.CreateICmpULT(input_spatial[i], input_bound);
+ llvm::Value* dim_not_in_hole = not_in_hole(
+ input_spatial[i], window.dimensions(i).base_dilation());
+ llvm::Value* dim_ok =
+ ir_builder_.CreateAnd(dim_in_bound, dim_not_in_hole);
+ in_bounds_condition =
+ in_bounds_condition
+ ? ir_builder_.CreateAnd(in_bounds_condition, dim_ok)
+ : dim_ok;
+ }
+
+ // Now we need to map the dilated base coordinates back to the actual
+ // data indices on the lhs.
+ const auto undilate = [&](llvm::Value* input_index,
+ int64 base_dilation) {
+ return ir_builder_.CreateSDiv(input_index,
+ ir_builder_.getInt64(base_dilation));
+ };
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial[i] =
+ undilate(input_spatial[i], window.dimensions(i).base_dilation());
+ }
+
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ in_bounds_condition, "in-bounds", &ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+
+ // We are not in the padding, so carry out the computation.
+ int num_dims = num_spatial_dims + 2;
+ llvm_ir::IrArray::Index input_index(num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_index[dnums.spatial_dimensions(i)] = input_spatial[i];
+ }
+ input_index[dnums.feature_dimension()] = input_feature;
+ input_index[dnums.batch_dimension()] = batch;
+
+ llvm_ir::IrArray kernel_array(GetIrArrayForOp(rhs));
+ llvm_ir::IrArray::Index kernel_index(num_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i];
+ }
+ kernel_index[dnums.kernel_input_feature_dimension()] = input_feature;
+ kernel_index[dnums.kernel_output_feature_dimension()] = output_feature;
+
+ llvm_ir::IrArray input_array(GetIrArrayForOp(lhs));
+ llvm::Value* product = ir_builder_.CreateFMul(
+ input_array.EmitReadArrayElement(input_index, &ir_builder_),
+ kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_));
+ llvm::Value* sum = ir_builder_.CreateFAdd(
+ ir_builder_.CreateLoad(sum_address), product);
+ ir_builder_.CreateStore(sum, sum_address);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(sum_address);
+ });
+}
+
+Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
+ // TODO(b/33011107): Support cross replica sum on CPU.
+ return Unimplemented(
+ "Cross replica sum not implemented on CPU. See b/33011107.");
+}
+
+Status IrEmitter::HandleParameter(HloInstruction* parameter) {
+ VLOG(2) << "HandleParameter: " << parameter->ToString();
+ auto param_number = parameter->parameter_number();
+ auto param_shape = parameter->shape();
+
+ // We have to access the parameter at offset param_number in the params
+ // array. The code generated here is equivalent to this C code:
+ //
+ // i8* param_address_untyped = params[param_number];
+ // Param* param_address_typed = (Param*)param_address_untyped;
+ //
+ // Where Param is the actual element type of the underlying buffer (for
+ // example, float for an XLA F32 element type).
+ llvm::Argument* params = GetArg(compute_function_, 2);
+ llvm::Value* param_address_offset =
+ llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_);
+ llvm::LoadInst* param_address_untyped =
+ ir_builder_.CreateLoad(param_address_offset);
+ llvm::Value* param_address_typed = ir_builder_.CreateBitCast(
+ param_address_untyped, IrShapeType(param_shape)->getPointerTo());
+ emitted_value_[parameter] = param_address_typed;
+
+ // Parameters of different types may not alias one another.
+ llvm_ir::SetTbaaForInstruction(param_address_untyped, param_shape,
+ /*is_pointer_to=*/true);
+ if (!ShapeUtil::IsOpaque(param_shape)) {
+ AttachAlignmentMetadataForLoad(param_address_untyped, param_shape);
+ AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape);
+ }
+
+ VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*param_address_typed);
+ return Status::OK();
+}
+
+Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) {
+ // The called computation should have been emitted previously.
+ llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
+ return EmitTargetElementLoop(
+ reduce, [this, reduce, arg, init_value, dimensions,
+ reducer_function](const llvm_ir::IrArray::Index& index) {
+ // Initialize an accumulator with init_value.
+ PrimitiveType accumulator_type = reduce->shape().element_type();
+ llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_),
+ "accumulator", &ir_builder_,
+ MinimumAlignmentForPrimitiveType(accumulator_type));
+ llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
+ llvm::Value* load_init_value = ir_builder_.CreateLoad(init_value_addr);
+ ir_builder_.CreateStore(load_init_value, accumulator_addr);
+
+ // The enclosing loops go over all the target elements. Now we have to
+ // compute the actual target element. For this, we build a new loop nest
+ // to iterate over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction
+ // Value*s are placed for each dimension in dimensions, and all the rest
+ // are nullptrs.
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ const llvm_ir::IrArray::Index reduced_dims_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Build a full index for the input argument, using reduced_dims_index
+ // as the base. In reduced_dims_index only the reduction dimensions are
+ // filled in. We fill in the rest of the dimensions with induction
+ // Value*s taken from 'index' which iterates over the target array.
+ // See the high-level description in the XLA documentation for details.
+ llvm_ir::IrArray arg_array(GetIrArrayForOp(arg));
+ llvm_ir::IrArray::Index input_index = reduced_dims_index;
+ llvm_ir::IrArray::Index::const_iterator it = index.begin();
+
+ for (int64 i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+
+ // Apply the reduction function to the loaded value.
+ llvm::Value* input_address =
+ arg_array.EmitArrayElementAddress(input_index, &ir_builder_);
+ llvm::Value* result = EmitElementFunctionCall(
+ reducer_function, reduce->shape(),
+ {accumulator_addr, input_address}, "reduce_function");
+ ir_builder_.CreateStore(result, accumulator_addr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_addr);
+ });
+}
+
+Status IrEmitter::HandleSend(HloInstruction* send) {
+ // TODO(b/33942983): Support Send/Recv on CPU.
+ return Unimplemented("Send is not implemented on CPU. See b/33942983.");
+}
+
+Status IrEmitter::HandleRecv(HloInstruction* recv) {
+ // TODO(b/33942983): Support Send/Recv on CPU.
+ return Unimplemented("Recv is not implemented on CPU. See b/33942983.");
+}
+
+Status IrEmitter::HandlePad(HloInstruction* pad) {
+ // First, fill in the padding value to all output elements.
+ TF_RETURN_IF_ERROR(EmitTargetElementLoop(
+ pad, [this, pad](const llvm_ir::IrArray::Index& target_index) {
+ const HloInstruction* padding_value = pad->operand(1);
+ llvm::Value* padding_value_addr = GetEmittedValueFor(padding_value);
+ return ir_builder_.CreateLoad(padding_value_addr);
+ }));
+
+ // Create a loop to iterate over the operand elements and update the output
+ // locations where the operand elements should be stored.
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ const HloInstruction* operand = pad->operand(0);
+ const llvm_ir::IrArray::Index operand_index =
+ loops.AddLoopsForShape(operand->shape(), "operand");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Load an element from the operand.
+ llvm_ir::IrArray operand_array(GetIrArrayForOp(operand));
+ llvm::Value* operand_data =
+ operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
+
+ // Compute the output index the operand element should be assigned to.
+ // output_index := edge_padding_low + operand_index * (interior_padding + 1)
+ const PaddingConfig& padding_config = pad->padding_config();
+ llvm_ir::IrArray::Index output_index;
+ for (int64 i = 0; i < operand_index.size(); ++i) {
+ llvm::Value* offset = ir_builder_.CreateMul(
+ operand_index[i],
+ ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() +
+ 1));
+ llvm::Value* index = ir_builder_.CreateAdd(
+ offset,
+ ir_builder_.getInt64(padding_config.dimensions(i).edge_padding_low()));
+ output_index.push_back(index);
+ }
+
+ // Store the operand element to the computed output location.
+ llvm_ir::IrArray output_array(GetIrArrayForOp(pad));
+ output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return Status::OK();
+}
+
+// If `hlo` is a rank-2 transpose, returns its operand; otherwise returns
+// `hlo` itself.
+static const HloInstruction* StripTranspose(const HloInstruction& hlo) {
+ if (hlo.IsRank2Transpose()) {
+ return hlo.operand(0);
+ }
+ return &hlo;
+}
+
+Status IrEmitter::HandleFusion(HloInstruction* fusion) {
+ if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) {
+ const HloInstruction* dot = fusion->fused_expression_root();
+ DCHECK(dot->opcode() == HloOpcode::kDot);
+ const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+ const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+ DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+ rhs_parameter->opcode() == HloOpcode::kParameter);
+ const HloInstruction* lhs =
+ fusion->operand(lhs_parameter->parameter_number());
+ const HloInstruction* rhs =
+ fusion->operand(rhs_parameter->parameter_number());
+
+ TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
+ /*instruction=*/*dot, /*operands=*/{lhs, rhs},
+ /*supported_types=*/{F32}));
+
+ llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs));
+ llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs));
+
+ Shape target_shape = fusion->shape();
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(fusion));
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*fusion, &target_array);
+
+ VLOG(2) << "HandleFusion kTransposeDot: ";
+ VLOG(2) << " lhs operand: "
+ << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
+ VLOG(2) << " rhs operand: "
+ << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
+ VLOG(2) << " target: "
+ << llvm_ir::DumpToString(*target_array.GetBasePointer());
+
+    // The dot operation is complicated, so we delegate to a helper class.
+ TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
+ *dot, dot->operand(0)->IsRank2Transpose(),
+ dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array,
+ GetExecutableRunOptionsArgument(), &ir_builder_));
+
+ emitted_value_[fusion] = target_address;
+ return Status::OK();
+ } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) {
+ std::vector<llvm_ir::IrArray> parameter_arrays;
+ for (HloInstruction* operand : fusion->operands()) {
+ parameter_arrays.push_back(GetIrArrayForOp(operand));
+ }
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
+ module_);
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
+
+ return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator());
+ } else {
+ return Unimplemented("Fusion kind not implemented on CPU");
+ }
+}
+
+Status IrEmitter::HandleCall(
+ HloInstruction* call, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) {
+ llvm::Function* call_ir_function = FindOrDie(emitted_functions_, computation);
+
+ std::vector<llvm::Value*> parameter_addresses;
+ for (HloInstruction* operand : operands) {
+ parameter_addresses.push_back(GetEmittedValueFor(operand));
+ }
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(call));
+
+ EmitArrayFunctionCallInto(call_ir_function, parameter_addresses,
+ output_address, computation->name());
+
+ emitted_value_[call] = output_address;
+ return Status::OK();
+}
+
+Status IrEmitter::HandleCustomCall(
+ HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) {
+ llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy();
+ llvm::AllocaInst* operands_alloca =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ i8_ptr_type, ir_builder_.getInt32(operands.size()),
+ "cc_operands_alloca", &ir_builder_);
+ for (int i = 0; i < operands.size(); ++i) {
+ const HloInstruction* operand = operands[i];
+ llvm::Value* operand_as_i8ptr =
+ ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
+ llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP(
+ operands_alloca, {ir_builder_.getInt32(i)});
+ ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
+ }
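+  // Declare (or look up) the custom-call target with the signature expected
+  // here: void target(i8* output, i8** operands).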
+ auto* custom_call_ir_function =
+ llvm::cast<llvm::Function>(module_->getOrInsertFunction(
+ llvm_ir::AsStringRef(custom_call_target),
+ llvm::FunctionType::get(
+ /*Result=*/ir_builder_.getVoidTy(),
+ /*Params=*/{i8_ptr_type, operands_alloca->getType()},
+ /*isVarArg=*/false)));
+
+ TF_ASSIGN_OR_RETURN(llvm::Value * output_address,
+ EmitTargetAddressForOp(custom_call));
+ auto* output_address_arg =
+ ir_builder_.CreatePointerCast(output_address, i8_ptr_type);
+
+ ir_builder_.CreateCall(custom_call_ir_function,
+ {output_address_arg, operands_alloca});
+
+ emitted_value_[custom_call] = output_address;
+ return Status::OK();
+}
+
+Status IrEmitter::HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+ HloComputation* condition, HloComputation* body) {
+ // Precondition: Condition computation must return a scalar bool.
+ TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
+ condition->root_instruction()->shape().element_type() == PRED)
+ << "While condition computation must return bool";
+ // Check that all while-related buffers share an allocation.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshape(
+ xla_while->shape(),
+ [this, &xla_while](const Shape& /*subshape*/,
+ const ShapeIndex& index) -> Status {
+ auto check = [this](const HloInstruction* a, const HloInstruction* b,
+ const ShapeIndex& index) {
+ BufferAllocation::Index index_a =
+ assignment_.GetUniqueAllocation(a, index)
+ .ConsumeValueOrDie()
+ ->index();
+ BufferAllocation::Index index_b =
+ assignment_.GetUniqueAllocation(b, index)
+ .ConsumeValueOrDie()
+ ->index();
+ if (index_a != index_b) {
+ return InternalError(
+ "instruction %s does not share allocation with "
+ "instruction %s ",
+ a->ToString().c_str(), b->ToString().c_str());
+ }
+ return Status::OK();
+ };
+ TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
+ TF_RETURN_IF_ERROR(check(
+ xla_while, xla_while->while_condition()->parameter_instruction(0),
+ index));
+ TF_RETURN_IF_ERROR(
+ check(xla_while, xla_while->while_body()->parameter_instruction(0),
+ index));
+ TF_RETURN_IF_ERROR(check(
+ xla_while, xla_while->while_body()->root_instruction(), index));
+ return Status::OK();
+ }));
+
+ // Set emitted value to that of 'init' with which it shares an allocation.
+ emitted_value_[xla_while] = GetEmittedValueFor(init);
+
+ // The called computation should have been emitted previously.
+ llvm::Function* condition_ir_function =
+ FindOrDie(emitted_functions_, condition);
+ llvm::Function* body_ir_function = FindOrDie(emitted_functions_, body);
+
+ // Generating:
+ // while (Condition(while_result)) {
+ // // CopyInsertion pass inserts copies which enable 'while_result' to
+ // // be passed back in as 'Body' parameter.
+  //     while_result = Body(while_result);
+ // }
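+  //
+  // The emitted control flow is: the current block branches to 'while_header',
+  // which evaluates the condition and branches to either 'while_body' or the
+  // exit block; 'while_body' branches back to 'while_header'.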
+
+ // Terminates the current block with a branch to a while header.
+ llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
+ module_->getContext(), "while_header", compute_function_);
+ ir_builder_.CreateBr(header_bb);
+ ir_builder_.SetInsertPoint(header_bb);
+
+ // Calls the condition function to determine whether to proceed with the
+ // body. It must return a bool, so use the scalar call form.
+ llvm::Value* while_result = GetEmittedValueFor(xla_while);
+ llvm::Value* while_condition = EmitElementFunctionCall(
+ condition_ir_function, condition->root_instruction()->shape(),
+ {while_result}, "condition_function");
+ llvm::Value* while_predicate = ir_builder_.CreateICmpNE(
+ while_condition,
+ llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
+ 0));
+
+ // Branches to the body or to the while exit depending on the condition.
+ llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
+ module_->getContext(), "while_body", compute_function_);
+  llvm::BasicBlock* exit_bb =
+      llvm::BasicBlock::Create(module_->getContext(), "while_exit");
+ ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb);
+
+ // Calls the body function from the body block.
+ ir_builder_.SetInsertPoint(body_bb);
+
+ // Calls the body function.
+ EmitArrayFunctionCallInto(body_ir_function, {while_result}, while_result,
+ "while_body");
+ // Finishes with a branch back to the header.
+ ir_builder_.CreateBr(header_bb);
+
+ // Adds the exit block to the function and sets the insert point there.
+ compute_function_->getBasicBlockList().push_back(exit_bb);
+ ir_builder_.SetInsertPoint(exit_bb);
+
+ return Status::OK();
+}
+
+Status IrEmitter::FinishVisit(HloInstruction* root) {
+ // When this method is called, we should have already emitted an IR value for
+ // the root (return) op. The IR value holds the address of the buffer holding
+ // the value. If the root is a constant or parameter, we perform a memcpy from
+ // this buffer to the retval buffer of the computation. Otherwise, there's
+ // nothing to do since the result was already written directly into the output
+ // buffer.
+ VLOG(2) << "FinishVisit root: " << root->ToString();
+ llvm::Value* root_value = GetEmittedValueFor(root);
+ VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value);
+
+ if (auto* prof_counter = GetProfileCounterFor(/*hlo=*/nullptr)) {
+ profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter);
+ }
+
+ ir_builder_.CreateRetVoid();
+ return Status::OK();
+}
+
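+// Returns a GEP into the profile counter buffer for the given HLO, or nullptr
+// if profiling is disabled or the HLO is not being profiled. A null `hlo`
+// selects the counter just past the per-HLO counters, which accumulates cycles
+// for the entire computation.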
+llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction* hlo) {
+ string counter_name;
+ size_t prof_counter_idx;
+ if (!hlo_to_profile_idx_) {
+ return nullptr;
+ }
+ if (hlo) {
+ auto it = hlo_to_profile_idx_->find(hlo);
+ if (it == hlo_to_profile_idx_->end()) {
+ return nullptr;
+ }
+
+ prof_counter_idx = it->second;
+ uintptr_t hlo_address = reinterpret_cast<uintptr_t>(hlo);
+ counter_name = tensorflow::strings::StrCat(
+ "prof_counter_0x",
+ tensorflow::strings::Hex(
+ hlo_address, tensorflow::strings::PadSpec(sizeof(hlo_address))));
+ } else {
+ prof_counter_idx = hlo_to_profile_idx_->size();
+ counter_name = "prof_counter_computation";
+ }
+ return ir_builder_.CreateGEP(GetProfileCountersArgument(),
+ ir_builder_.getInt64(prof_counter_idx),
+ llvm_ir::AsStringRef(counter_name));
+}
+
+void IrEmitter::ProfilingState::UpdateProfileCounter(
+ llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter,
+ llvm::Value* cycle_end, llvm::Value* cycle_start) {
+ auto* cycle_diff = ir_builder->CreateSub(cycle_end, cycle_start);
+ llvm::LoadInst* old_cycle_count =
+ ir_builder->CreateLoad(prof_counter, "old_cycle_count");
+ auto* new_cycle_count =
+ ir_builder->CreateAdd(cycle_diff, old_cycle_count, "new_cycle_count");
+ ir_builder->CreateStore(new_cycle_count, prof_counter);
+}
+
+llvm::Value* IrEmitter::ProfilingState::ReadCycleCounter(
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+  if (!use_rdtscp_) {
+ llvm::Function* func_llvm_readcyclecounter =
+ llvm::Intrinsic::getDeclaration(module,
+ llvm::Intrinsic::readcyclecounter);
+ return ir_builder->CreateCall(func_llvm_readcyclecounter);
+ }
+ llvm::Function* func_llvm_x86_rdtscp =
+ llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::x86_rdtscp);
+ if (!aux_i8ptr_) {
+ llvm::AllocaInst* rdtscp_aux = llvm_ir::EmitAllocaAtFunctionEntry(
+ ir_builder->getInt32Ty(), "rdtscp_aux", ir_builder);
+ aux_i8ptr_ =
+ ir_builder->CreateBitCast(rdtscp_aux, ir_builder->getInt8PtrTy());
+ }
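+  // Bracket the rdtscp call with lifetime markers for the 4-byte aux slot so
+  // LLVM knows the alloca is live only around this call.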
+ llvm::ConstantInt* alloca_size = ir_builder->getInt64(4);
+ llvm::Function* func_llvm_lifetime_start =
+ llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_start);
+ ir_builder->CreateCall(func_llvm_lifetime_start, {alloca_size, aux_i8ptr_});
+ llvm::Value* rdtscp_call =
+ ir_builder->CreateCall(func_llvm_x86_rdtscp, aux_i8ptr_);
+ llvm::Function* func_llvm_lifetime_end =
+ llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::lifetime_end);
+ ir_builder->CreateCall(func_llvm_lifetime_end, {alloca_size, aux_i8ptr_});
+ return rdtscp_call;
+}
+
+void IrEmitter::ProfilingState::RecordCycleStart(llvm::IRBuilder<>* ir_builder,
+ HloInstruction* hlo) {
+ auto* cycle_start = ReadCycleCounter(ir_builder);
+ cycle_starts_[hlo] = cycle_start;
+ if (first_read_cycle_start_ == nullptr) {
+ first_read_cycle_start_ = cycle_start;
+ }
+}
+
+void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder,
+ HloInstruction* hlo,
+ llvm::Value* prof_counter) {
+ auto* cycle_end = ReadCycleCounter(ir_builder);
+ auto* cycle_start = cycle_starts_[hlo];
+ UpdateProfileCounter(ir_builder, prof_counter, cycle_end, cycle_start);
+ last_read_cycle_end_ = cycle_end;
+}
+
+void IrEmitter::ProfilingState::RecordCompleteComputation(
+ llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) {
+ if (is_entry_computation_ && last_read_cycle_end_ &&
+ first_read_cycle_start_) {
+ UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_,
+ first_read_cycle_start_);
+ }
+}
+
+Status IrEmitter::Preprocess(HloInstruction* hlo) {
+ if (hlo_to_profile_idx_ && hlo_to_profile_idx_->count(hlo)) {
+ profiling_state_.RecordCycleStart(&ir_builder_, hlo);
+ }
+ return Status::OK();
+}
+
+Status IrEmitter::Postprocess(HloInstruction* hlo) {
+ if (auto* prof_counter = GetProfileCounterFor(hlo)) {
+ profiling_state_.RecordCycleDelta(&ir_builder_, hlo, prof_counter);
+ }
+ return Status::OK();
+}
+
+llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) {
+ llvm::Value* value_for_op = GetEmittedValueFor(hlo);
+
+ llvm_ir::IrArray array(value_for_op, hlo->shape());
+ AddAliasingInformationToIrArray(*hlo, &array);
+ return array;
+}
+
+llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) {
+ auto it = emitted_value_.find(hlo);
+ if (it == emitted_value_.end()) {
+ LOG(FATAL) << "could not find emitted value for: " << hlo->ToString();
+ }
+ return it->second;
+}
+
+llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
+ return llvm_ir::ShapeToIrType(shape, &ir_builder_);
+}
+
+llvm::Argument* IrEmitter::GetResultArgument() {
+ return GetArg(compute_function_, 0);
+}
+
+llvm::Argument* IrEmitter::GetProfileCountersArgument() {
+ return hlo_to_profile_idx_ ? GetArg(compute_function_, 4) : nullptr;
+}
+
+llvm::Value* IrEmitter::GetTempBuffersArgument() {
+ return GetArg(compute_function_, 3);
+}
+
+llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
+ return GetArg(compute_function_, 1);
+}
+
+llvm::Value* IrEmitter::EmitTempBufferPointer(
+ BufferAllocation::Index temp_buf_index, const Shape& target_shape) {
+ llvm::Type* element_type = IrShapeType(target_shape);
+  // The alignment and number of bytes of the temporary buffer are determined
+  // by the maximal shape, as computed by buffer assignment.
+ const BufferAllocation& allocation =
+ assignment_.GetAllocation(temp_buf_index);
+ if (allocation.is_thread_local()) {
+ // Thread-local allocations should only be assigned a single buffer.
+ CHECK_EQ(1, allocation.assigned_buffers().size());
+ const Shape& shape = allocation.assigned_buffers()[0]->shape();
+
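+    // Cache one alloca per (function, allocation index) pair so repeated
+    // requests within the same function reuse the same stack slot.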
+ llvm::AllocaInst*& tempbuf_address = thread_local_buffers_[{
+ ir_builder_.GetInsertBlock()->getParent(), temp_buf_index}];
+ if (tempbuf_address == nullptr) {
+ tempbuf_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ IrShapeType(shape),
+ tensorflow::strings::StrCat("thread_local", temp_buf_index),
+ &ir_builder_, MinimumAlignmentForShape(target_shape));
+ }
+ return ir_builder_.CreateBitCast(tempbuf_address,
+ element_type->getPointerTo());
+ }
+
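+  // For non-thread-local allocations, load the buffer's address from the
+  // 'temps' array argument of the compute function.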
+ llvm::Value* tempbuf_address_offset = llvm_ir::EmitBufferIndexingGEP(
+ GetTempBuffersArgument(), temp_buf_index, &ir_builder_);
+ llvm::LoadInst* tempbuf_address_untyped =
+ ir_builder_.CreateLoad(tempbuf_address_offset);
+  // The address loaded for a buffer is invariant with respect to where in the
+  // program the load executes, because we never reassign buffers.
+ tempbuf_address_untyped->setMetadata(
+ llvm::LLVMContext::MD_invariant_load,
+ llvm::MDNode::get(tempbuf_address_untyped->getContext(), /*MDs=*/{}));
+ llvm_ir::SetTbaaForInstruction(tempbuf_address_untyped, target_shape,
+ /*is_pointer_to=*/true);
+
+ AttachAlignmentMetadataForLoad(tempbuf_address_untyped, allocation.size());
+ AttachDereferenceableMetadataForLoad(tempbuf_address_untyped,
+ allocation.size());
+ return ir_builder_.CreateBitCast(tempbuf_address_untyped,
+ element_type->getPointerTo());
+}
+
+// Emits a function call that returns a single array element. Allocates space
+// for a single element_type value and loads it after the call.
+llvm::Value* IrEmitter::EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* return_value_buffer = EmitArrayFunctionCall(
+ function, return_shape, 1, parameter_addresses, name);
+ return ir_builder_.CreateLoad(
+ return_value_buffer,
+ llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_return_value")));
+}
+
+// Emits a core function call based on the following pseudo-code.
+//
+// char** parameter_addresses_buffer =
+// allocate buffer with a pointer for each parameter to the function
+// for each parameter index, i.e. for i = 0, ..., #parameters:
+// parameter_addresses_buffer[i] = parameter_addresses[i]
+// call function(return_value_buffer,
+// parameter_addresses_buffer,
+// temps)
+// return return_value_buffer -- address of the return value.
+void IrEmitter::EmitArrayFunctionCallInto(
+ llvm::Function* function,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value_buffer, tensorflow::StringPiece name) {
+ llvm::Value* parameter_addresses_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ ir_builder_.getInt8PtrTy(),
+ ir_builder_.getInt32(parameter_addresses.size()),
+ tensorflow::strings::StrCat(name, "_parameter_addresses"),
+ &ir_builder_);
+ for (int i = 0; i < parameter_addresses.size(); ++i) {
+ llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast(
+ parameter_addresses[i], ir_builder_.getInt8PtrTy(),
+ llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
+ "_address_as_i8ptr")));
+    llvm::Value* slot_in_param_addresses = ir_builder_.CreateInBoundsGEP(
+        parameter_addresses_buffer, {ir_builder_.getInt32(i)});
+    ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
+ }
+
+ const auto to_int8_ptr = [this](llvm::Value* ptr) {
+ return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy());
+ };
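+  // Pass arguments in the compute-function ABI order: retval, run_options,
+  // parameter addresses, temps, and (when profiling) the profile counters.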
+ std::vector<llvm::Value*> arguments{
+ to_int8_ptr(return_value_buffer),
+ to_int8_ptr(GetExecutableRunOptionsArgument()),
+ parameter_addresses_buffer, GetTempBuffersArgument()};
+ if (auto* profile_counters = GetProfileCountersArgument()) {
+ arguments.push_back(profile_counters);
+ }
+ ir_builder_.CreateCall(function, arguments);
+}
+
+llvm::Value* IrEmitter::EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name) {
+ llvm::Value* elements =
+ llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count);
+ PrimitiveType return_type = return_shape.element_type();
+ llvm::Value* return_value_buffer =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements,
+ tensorflow::strings::StrCat(name, "_return_value_address"),
+ &ir_builder_, MinimumAlignmentForPrimitiveType(return_type));
+ EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer,
+ name);
+ return return_value_buffer;
+}
+
+StatusOr<llvm::Value*> IrEmitter::EmitTargetAddressForOp(
+ const HloInstruction* op) {
+ const Shape& target_shape = op->shape();
+ if (op == op->parent()->root_instruction()) {
+ // For the root node, we write directly to the output buffer of the
+ // function.
+ llvm::Argument* retval = GetResultArgument();
+ if (!ShapeUtil::HasZeroElements(target_shape)) {
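+      // Annotate the retval argument with its required alignment and the
+      // number of dereferenceable bytes implied by the shape.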
+ llvm::AttrBuilder attr_builder;
+ attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
+ attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape));
+ retval->addAttr(llvm::AttributeSet::get(
+ retval->getContext(), retval->getArgNo() + 1, attr_builder));
+ }
+ return ir_builder_.CreateBitCast(retval,
+ IrShapeType(target_shape)->getPointerTo());
+ }
+
+ // For other nodes, we need the temporary buffer allocated for this node to
+ // write the result into.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ assignment_.GetUniqueTopLevelAllocation(op));
+ return EmitTempBufferPointer(allocation->index(), target_shape);
+}
+
+Status IrEmitter::EmitTargetElementLoop(
+ HloInstruction* target_op,
+ const llvm_ir::ElementGenerator& element_generator) {
+ VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString();
+
+  // target_address will hold the address of the target buffer into which we
+  // will write the result of the computation.
+ const Shape& target_shape = target_op->shape();
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
+ EmitTargetAddressForOp(target_op));
+ VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address);
+ llvm_ir::IrArray target_array(target_address, target_shape);
+ AddAliasingInformationToIrArray(*target_op, &target_array);
+
+ TF_RETURN_IF_ERROR(
+ llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
+ .EmitLoop());
+ emitted_value_[target_op] = target_address;
+ return Status::OK();
+}
+
+Status IrEmitter::EmitMemcpy(const HloInstruction& source,
+ const HloInstruction& destination) {
+ llvm::Value* source_value = GetEmittedValueFor(&source);
+ llvm::Value* destination_value = GetEmittedValueFor(&destination);
+ int64 source_size = ByteSizeOf(source.shape());
+ ir_builder_.CreateMemCpy(destination_value, source_value, source_size, 1);
+ return Status::OK();
+}
+
+Status IrEmitter::ElementTypesSameAndSupported(
+ const HloInstruction& instruction,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<PrimitiveType> supported_types) {
+ for (auto operand : operands) {
+ TF_RET_CHECK(
+ ShapeUtil::SameElementType(operands[0]->shape(), operand->shape()));
+ }
+
+ TF_RET_CHECK(!operands.empty());
+ PrimitiveType primitive_type = operands[0]->shape().element_type();
+ if (std::find(supported_types.begin(), supported_types.end(),
+ primitive_type) == supported_types.end()) {
+ return Unimplemented("unsupported operand type %s in op %s",
+ PrimitiveType_Name(primitive_type).c_str(),
+ HloOpcodeString(instruction.opcode()).c_str());
+ }
+ return Status::OK();
+}
+
+Status IrEmitter::DefaultAction(HloInstruction* hlo) {
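+  // For each operand, build a generator that reads the operand element at the
+  // requested index; the elemental emitter composes these into an element
+  // generator for `hlo` itself.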
+ ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
+ for (const HloInstruction* operand : hlo->operands()) {
+ operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_);
+ };
+ }
+ CpuElementalIrEmitter elemental_emitter(hlo_module_config_, &ir_builder_,
+ module_);
+ return EmitTargetElementLoop(
+ hlo, elemental_emitter.MakeElementGenerator(hlo, operand_to_generator));
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
new file mode 100644
index 0000000000..06415c735d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -0,0 +1,402 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
+
+#include <stddef.h>
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace cpu {
+
+// This class is the top-level API for the XLA HLO --> LLVM IR compiler. It
+// implements the DfsHloVisitor interface and emits HLO computations as LLVM IR
+// functions.
+class IrEmitter : public DfsHloVisitorWithDefault {
+ public:
+ // Create a new LLVM IR emitter.
+ //
+ // hlo_module: the HLO module we are emitting IR for.
+ // assignment: a BufferAssignment from which we know which temporary buffers
+ // are used by the HLO nodes.
+ // llvm_module: the LLVM module to emit IR into.
+ // hlo_to_profile_idx: the mapping from HLO to its index in the profiling
+ // array.
+ IrEmitter(const HloModule& hlo_module, const HloModuleConfig& module_config,
+ const BufferAssignment& assignment, llvm::Module* llvm_module,
+ const std::unordered_map<const HloInstruction*, size_t>*
+ hlo_to_profile_idx);
+ ~IrEmitter() override;
+
+ // Emit and return the given HLO computation as an LLVM IR
+ // function. function_name_prefix is the desired name of the function. If the
+ // name is not unique among already emitted functions then a suffix is
+ // appended to make the name unique. is_entry_computation indicates that this
+ // is the entry computation of the HLO module. If 'instruction_order' is given
+ // then the HLO instructions are emitted in the given order. In this case,
+ // 'instruction_order' must be a topological sort of the set of nodes
+ // accessible from the root of the computation.
+ StatusOr<llvm::Function*> EmitComputation(
+ HloComputation* computation, const string& function_name_prefix,
+ bool is_entry_computation,
+ std::vector<const HloInstruction*>* instruction_order = nullptr);
+
+ protected:
+ //
+ // The following methods implement the DfsHloVisitor interface.
+ //
+ // Default action which emits code for most operations. Operations which are
+ // special in some way are handled explicitly in HandleFoo methods.
+ Status DefaultAction(HloInstruction* hlo_instruction) override;
+
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) override;
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+ Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleInfeed(HloInstruction* infeed) override;
+ Status HandleSort(HloInstruction* sort, HloInstruction* operand) override;
+ Status HandleParameter(HloInstruction* parameter) override;
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override;
+ Status HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* operand, const Window& window,
+ HloComputation* function) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleSend(HloInstruction* send) override;
+ Status HandleRecv(HloInstruction* recv) override;
+ Status HandlePad(HloInstruction* pad) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleCall(HloInstruction* call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) override;
+ Status HandleCustomCall(HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) override;
+ Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+ HloComputation* condition, HloComputation* body) override;
+ Status FinishVisit(HloInstruction* root) override;
+
+ Status Preprocess(HloInstruction* hlo) override;
+ Status Postprocess(HloInstruction* visited) override;
+
+ private:
+ // Private helper to initialize an IR function for the computation.
+ void InitializeIrFunction(const string& function_name,
+ bool is_entry_computation);
+
+ // Convenience function to generate a GEP into the profile counter parameter
+ // which would correspond to the index for a given HLO.
+ llvm::Value* GetProfileCounterFor(const HloInstruction* hlo);
+
+ // Convenience function to get the IR Value emitted previously for the given
+  // hlo. Call it only when you are certain a value was emitted; if none is
+  // found, this logs a fatal error.
+ llvm::Value* GetEmittedValueFor(const HloInstruction* hlo);
+
+ // Convenience function to get an IrArray representing the given hlo.
+ llvm_ir::IrArray GetIrArrayForOp(const HloInstruction* hlo);
+
+ // Augments IrArray with aliasing information.
+ void AddAliasingInformationToIrArray(const HloInstruction& hlo,
+ llvm_ir::IrArray* array) {
+ alias_analysis_.AddAliasingInformationToIrArray(hlo, array);
+ }
+
+ // Convenience function to get the IR type matching the given shape.
+ llvm::Type* IrShapeType(const Shape& shape);
+
+ // Get the llvm::Value* that represents the "retval" argument of the
+ // computation function being emitted by this emitter.
+ llvm::Argument* GetResultArgument();
+
+ // Get the llvm::Value* that represents the "prof_counters" argument of the
+ // computation function being emitted by this emitter.
+ llvm::Argument* GetProfileCountersArgument();
+
+ // Get the xla::ExecutableRunOptions that represents the "run_options"
+ // argument of the computation function being emitted by this emitter.
+ llvm::Value* GetExecutableRunOptionsArgument();
+
+ // Get the llvm::Value* that represents the "temps" argument of the
+ // computation function being emitted by this emitter.
+ llvm::Value* GetTempBuffersArgument();
+
+  // Emits code into the function that computes the address of the given
+  // temporary buffer. target_shape is the shape of this temporary buffer.
+  // The returned Value's type is a pointer to element_type.
+ llvm::Value* EmitTempBufferPointer(BufferAllocation::Index temp_buf_index,
+ const Shape& target_shape);
+
+ // Emits a function into the current module. This can be used for
+ // computations embedded inside other computations, such as the
+ // function that a map operation applies.
+ StatusOr<llvm::Function*> EmitFunction(
+ HloComputation* function, // The function to emit.
+ tensorflow::StringPiece
+ function_name_suffix); // Used for LLVM IR register names.
+
+ // Methods that emit a function call.
+ // Parameters:
+ // function - The LLVM function to call.
+ // return_shape - The return shape of the HLO computation that was used to
+ // make the function. Not the same as the return type of the function
+  //     in LLVM, since the return value is passed through an output parameter.
+ // element_count - number of elements to return (array form only).
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // name - used for LLVM IR register names.
+
+ // Emits a function call, returning a scalar, often an element of a larger
+ // array. Returns a Value for the scalar element returned by the function.
+ llvm::Value* EmitElementFunctionCall(
+ llvm::Function* function, const Shape& return_shape,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name);
+
+ // Array function call emitter. Stores the function's result into a supplied
+ // buffer.
+ // Parameters:
+ // function - The LLVM function to call.
+ // parameter_addresses - pointers to be passed to the function as
+ // parameters.
+ // return_value - pointer to a buffer where the call result is stored.
+
+ void EmitArrayFunctionCallInto(
+ llvm::Function* function,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ llvm::Value* return_value, tensorflow::StringPiece name);
+
+ // Array function call emitter. Returns a Value for the function's return
+ // value buffer address. The return value buffer is alloca'ed by this
+ // function.
+ llvm::Value* EmitArrayFunctionCall(
+ llvm::Function* function, const Shape& return_shape, int64 element_count,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_addresses,
+ tensorflow::StringPiece name);
+
+ // Verifies that the element types of all of the given operand instructions
+ // match and are of one of the given supported types.
+ Status ElementTypesSameAndSupported(
+ const HloInstruction& instruction,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
+ tensorflow::gtl::ArraySlice<PrimitiveType> supported_types);
+
+ // Emit IR to perform a computation for every element in the given target op.
+ // This produces a series of nested loops (one for each dimension of the op's
+  // shape). The body of the inner-most loop is provided by the
+  // element_generator function.
+ //
+ // TODO(jingyue): target_op should be a `const HloInstruction*`.
+ Status EmitTargetElementLoop(
+ HloInstruction* target_op,
+ const llvm_ir::ElementGenerator& element_generator);
+
+ // Emits a memcpy from the source instruction's result value to the
+ // destination's. Both source and destination must have an entry in the
+ // emitted_value_ table.
+ Status EmitMemcpy(const HloInstruction& source,
+ const HloInstruction& destination);
+
+ // Emit IR to compute the target address of the buffer for the given op.
+ // The returned Value is a pointer to a IR type that represents the op's
+ // element type.
+ StatusOr<llvm::Value*> EmitTargetAddressForOp(const HloInstruction* op);
+
+  // Structures "array_elements" into a multidimensional array representing "shape".
+ // This is a recursive function, and "dimension_index" indicates the index of
+ // the current dimension that the function is considering (0 means the
+ // most-minor dimension).
+ llvm::Constant* CreateInitializerForConstantArray(
+ const std::vector<llvm::Constant*>& array_elements, const Shape& shape,
+ int64 dimension_index);
+
+ // Name of the computation entry function. This function serves as the
+ // top-level "main" of the computation and will be invoked by the JIT.
+ string entry_function_name_;
+
+ // Assignment of the temporary buffers needed by the computation and their
+ // shape information.
+ const BufferAssignment& assignment_;
+
+ // The LLVM module into which IR will be emitted.
+ llvm::Module* module_;
+
+ // The target architecture.
+ llvm::Triple::ArchType arch_type_;
+
+ // Used to produce unique names for generated functions.
+ NameUniquer name_uniquer_;
+
+ // Map containing all previously emitted computations.
+ std::map<HloComputation*, llvm::Function*> emitted_functions_;
+
+ // Map containing all previously emitted thread-local temporary buffers.
+ std::map<std::pair<llvm::Function*, BufferAllocation::Index>,
+ llvm::AllocaInst*>
+ thread_local_buffers_;
+
+ // The following fields track the IR emission state. According to LLVM memory
+ // management rules, their memory is owned by the module.
+ llvm::Function* compute_function_;
+ llvm::IRBuilder<> ir_builder_;
+
+ // Maps HLOs to their index into the profile counter array.
+ const std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
+
+ // Maps HLOs to Values emitted for them.
+ std::unordered_map<const HloInstruction*, llvm::Value*> emitted_value_;
+
+ llvm_ir::AliasAnalysis alias_analysis_;
+
+ // This struct contains all the state needed to emit instructions for
+ // profiling a computation.
+ class ProfilingState {
+ public:
+ ProfilingState()
+ : is_entry_computation_(false),
+ use_rdtscp_(false),
+ prof_counters_(nullptr) {}
+ ProfilingState(bool is_entry_computation, bool use_rdtscp,
+ llvm::Argument* prof_counters)
+ : is_entry_computation_(is_entry_computation),
+ use_rdtscp_(use_rdtscp),
+ prof_counters_(prof_counters) {}
+
+ // Record the cycle counter before an HLO executes.
+ void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo);
+ // Record the number of cycles it took for an HLO to execute.
+ void RecordCycleDelta(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo,
+ llvm::Value* prof_counter);
+ // Record the number of cycles it took for the entire computation to
+ // execute.
+ void RecordCompleteComputation(llvm::IRBuilder<>* ir_builder,
+ llvm::Value* prof_counter);
+
+ // Convenience function to generate a call to an intrinsic which reads the
+ // CPU cycle counter.
+ llvm::Value* ReadCycleCounter(llvm::IRBuilder<>* ir_builder);
+
+ // Store the cycle counter delta to the per-HLO profile counter.
+ void UpdateProfileCounter(llvm::IRBuilder<>* ir_builder,
+ llvm::Value* prof_counter, llvm::Value* cycle_end,
+ llvm::Value* cycle_start);
+
+ private:
+ // Is this IrEmitter for a top-level computation?
+ bool is_entry_computation_;
+
+ // Should we use the x86-specific rdtscp or the generic readcyclecounter
+ // intrinsic?
+ bool use_rdtscp_;
+
+ // The argument which corresponds to the profile counter buffer.
+ llvm::Argument* prof_counters_;
+
+ // The first read cycle counter in the program.
+ llvm::Value* first_read_cycle_start_ = nullptr;
+
+ // The last read cycle counter in the program.
+ llvm::Value* last_read_cycle_end_ = nullptr;
+
+  // An alloca used to hold the aux value returned by the rdtscp
+  // intrinsic.
+ llvm::Value* aux_i8ptr_ = nullptr;
+
+ // Maps HLOs to the value the cycle counter contained right before the HLO
+ // began to execute.
+ std::unordered_map<const HloInstruction*, llvm::Value*> cycle_starts_;
+ };
+
+ ProfilingState profiling_state_;
+
+ // Given a load instruction and a shape or buffer size, annotate the load's
+ // result with the alignment required by the shape or size.
+ void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, const Shape& shape);
+ void AttachAlignmentMetadataForLoad(llvm::LoadInst* load, int64 buffer_size);
+
+ // Given a load instruction and a shape or buffer size, annotate the load's
+ // result with the dereferenceable bytes required by the shape / buffer size.
+ void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ const Shape& shape);
+ void AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ int64 buffer_size);
+
+ // Calculate the alignment of a buffer allocated for a given shape.
+ int MinimumAlignmentForShape(const Shape& shape);
+
+ // Calculate the alignment of a buffer allocated for a given primitive type.
+ int MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type);
+
+ // Calculate the alignment of a buffer with a particular size.
+ int MinimumAlignmentForBufferSize(int64 buffer_size);
+
+ // Returns the number of bytes within the shape.
+ int64 ByteSizeOf(const Shape& shape) const;
+
+ const HloModuleConfig& hlo_module_config_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter);
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
new file mode 100644
index 0000000000..136a8c9641
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h"
+
+#include <numeric>
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace cpu {
+
+Status CpuLayoutAssignment::AddBackendConstraints(
+ LayoutConstraints* constraints) {
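+  // Returns a copy of `old_shape` whose layout is set to row-major.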
+ auto row_major_shape = [](const Shape& old_shape) {
+ Shape new_shape(old_shape);
+ std::vector<int64> dimension_order(new_shape.dimensions_size());
+ std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+ *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+ return new_shape;
+ };
+ const HloComputation* computation = constraints->computation();
+ for (auto& instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kConvolution &&
+ PotentiallyImplementedAsEigenConvolution(*instruction)) {
+ const HloInstruction* convolution = instruction.get();
+ const HloInstruction* lhs_instruction = convolution->operand(0);
+ const HloInstruction* rhs_instruction = convolution->operand(1);
+
+ // In order to implement `convolution` with Eigen convolution, the layouts
+ // of the input, filter, and output need to be row-major.
+ //
+ // These constraints are not hard constraints. Ideally, we should decide
+ // which layouts to choose according to some cost model.
+ Shape output_shape(row_major_shape(convolution->shape()));
+ Shape input_shape(row_major_shape(lhs_instruction->shape()));
+ Shape filter_shape(row_major_shape(rhs_instruction->shape()));
+
+ // Set layouts of the instructions' shapes.
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(input_shape, convolution, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(filter_shape, convolution, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(output_shape, convolution));
+ } else if (PotentiallyImplementedAsEigenDot(*instruction)) {
+ const HloInstruction* dot = instruction.get();
+ const HloInstruction* lhs_instruction = dot->operand(0);
+ const HloInstruction* rhs_instruction = dot->operand(1);
+
+ // In order to implement `dot` with Eigen dot, the layouts of the lhs,
+ // rhs, and output need to be row-major.
+ //
+ // These constraints are not hard constraints. Ideally, we should decide
+ // which layouts to choose according to some cost model.
+ Shape output_shape(row_major_shape(dot->shape()));
+ Shape lhs_shape(row_major_shape(lhs_instruction->shape()));
+ Shape rhs_shape(row_major_shape(rhs_instruction->shape()));
+
+ // Set layouts of the instructions' shapes.
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1));
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot));
+ } else {
+ for (int64 operand_no = 0; operand_no < instruction->operand_count();
+ ++operand_no) {
+ // Skip operands which already have a constraint.
+ if (constraints->OperandLayout(instruction.get(), operand_no) !=
+ nullptr) {
+ continue;
+ }
+ // Skip over forwarded operands.
+ if (constraints->OperandBufferForwarded(instruction.get(),
+ operand_no)) {
+ continue;
+ }
+ Shape operand_shape(
+ row_major_shape(instruction->operand(operand_no)->shape()));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ operand_shape, instruction.get(), operand_no));
+ }
+ // Skip over the root instruction for the top-level computation.
+ if (computation->parent()->entry_computation() == computation &&
+ computation->root_instruction() == instruction.get()) {
+ continue;
+ }
+ // Skip instructions which don't produce array shapes (tuples, opaque,
+ // etc.).
+ if (!ShapeUtil::IsArray(instruction->shape())) {
+ continue;
+ }
+ tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers =
+ constraints->points_to_analysis()
+ .GetPointsToSet(instruction.get())
+ .element({});
+      // Only force a layout if the instruction has not already been assigned
+      // one and its output buffer unambiguously belongs to this instruction.
+ if (buffers.size() == 1 &&
+ buffers[0]->instruction() == instruction.get() &&
+ constraints->BufferLayout(*buffers[0]) == nullptr) {
+ Shape output_shape(row_major_shape(instruction->shape()));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(output_shape, instruction.get()));
+ }
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/layout_assignment.h
new file mode 100644
index 0000000000..4fd8d68dd6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace cpu {
+
+// CPU-specific layout assignment pass which preassigns layouts to satisfy
+// layout constraints for operands and results of library calls.
+class CpuLayoutAssignment : public LayoutAssignment {
+ public:
+ explicit CpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ : LayoutAssignment(entry_computation_layout) {}
+ ~CpuLayoutAssignment() override {}
+
+ protected:
+ Status AddBackendConstraints(LayoutConstraints* constraints) override;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
new file mode 100644
index 0000000000..268a36a660
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -0,0 +1,365 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h"
+
+#include <stdint.h>
+#include <algorithm>
+#include <deque>
+#include <iterator>
+#include <list>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace cpu {
+
+ParallelCpuExecutable::ParallelCpuExecutable(
+ std::unique_ptr<SimpleOrcJIT> jit,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ std::unique_ptr<std::map<HloInstruction*, string>> function_names,
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
+ std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
+ aligned_constants)
+ : Executable(std::move(hlo_module), std::move(module_config)),
+ jit_(std::move(jit)),
+ assignment_(std::move(assignment)),
+ functions_names_(std::move(function_names)),
+ hlo_to_profile_idx_(std::move(hlo_to_profile_idx)),
+ aligned_constants_(std::move(aligned_constants)) {}
+
+// Type of the computation function we expect in the JIT.
+using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
+ uint64*);
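+// Its parameters are, in order: the result buffer, the ExecutableRunOptions,
+// the array of parameter addresses, the array of temporary buffers, and the
+// profile counters array.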
+
+// Given a pointer to an output buffer (following the CPU JIT calling
+// conventions), mark addresses that are "live". The initial pointer itself is
+// trivially live. If the shape of the buffer is a tuple, this analysis looks
+// into the tuple's elements, marks them live as well (since tuples keep
+// pointers to their element buffers), and recurses into nested tuples.
+// address is an in-memory buffer address that contains some runtime XLA object.
+// shape is its shape. marked_addresses is the set of live addresses to
+// populate.
+static void MarkLiveAddressesInOutput(
+ const void* address, const Shape& shape,
+ std::unordered_set<const void*>* marked_addresses) {
+ marked_addresses->insert(address);
+ const uintptr_t* address_buffer = static_cast<const uintptr_t*>(address);
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ const uintptr_t* element_address = address_buffer + i;
+ const void* element = reinterpret_cast<const void*>(*element_address);
+ MarkLiveAddressesInOutput(
+ element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses);
+ }
+ }
+}
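A minimal standalone sketch of the marking recursion above, using a hypothetical MiniShape stand-in rather than XLA's Shape/ShapeUtil; it only illustrates how a tuple buffer of pointers yields three live addresses (the tuple plus its two leaves):

#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

// Hypothetical, simplified stand-in for xla::Shape: a shape is either a leaf
// array or a tuple of element shapes.
struct MiniShape {
  bool is_tuple = false;
  std::vector<MiniShape> elements;
};

// Same recursion as MarkLiveAddressesInOutput: mark the buffer itself, then,
// for tuples, chase each stored pointer and recurse.
void MarkLive(const void* address, const MiniShape& shape,
              std::unordered_set<const void*>* marked) {
  marked->insert(address);
  if (!shape.is_tuple) return;
  const uintptr_t* slots = static_cast<const uintptr_t*>(address);
  for (size_t i = 0; i < shape.elements.size(); ++i) {
    MarkLive(reinterpret_cast<const void*>(slots[i]), shape.elements[i],
             marked);
  }
}

int main() {
  float leaf0[4] = {0};
  float leaf1[2] = {0};
  // A tuple buffer holds pointers to its element buffers.
  uintptr_t tuple[2] = {reinterpret_cast<uintptr_t>(leaf0),
                        reinterpret_cast<uintptr_t>(leaf1)};
  MiniShape shape;
  shape.is_tuple = true;
  shape.elements.resize(2);  // two array leaves
  std::unordered_set<const void*> marked;
  MarkLive(tuple, shape, &marked);
  std::cout << "marked " << marked.size() << " addresses\n";  // prints 3
  return 0;
}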
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+ParallelCpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
+ if (!arguments.empty()) {
+ VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
+ }
+
+ // Allocate the temporary buffers required for the computation.
+ se::StreamExecutor* stream_executor = stream->parent();
+ int device_ordinal = stream_executor->device_ordinal();
+ int64 buffer_count = assignment_->Allocations().size();
+ VLOG(3) << "temp buffer count: " << buffer_count;
+
+ std::vector<se::DeviceMemoryBase> device_allocations;
+ for (BufferAllocation::Index i = 0; i < buffer_count; ++i) {
+ auto& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_entry_computation_parameter()) {
+ // Buffers do not need to be allocated for parameters.
+ device_allocations.push_back(se::DeviceMemoryBase(nullptr));
+ continue;
+ }
+
+ if (allocation.is_thread_local()) {
+ // Buffers do not need to be allocated for thread-local temporaries.
+ device_allocations.push_back(se::DeviceMemoryBase(nullptr));
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ se::DeviceMemoryBase device_allocation,
+ memory_allocator->Allocate(device_ordinal, allocation.size()));
+
+ // TODO(eliben): refactor this into buffer_assignment
+ if (VLOG_IS_ON(3)) {
+ VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size()
+ << " bytes for allocation #" << i << " ["
+ << device_allocation.opaque() << "]";
+ std::vector<string> parts;
+ for (const LogicalBuffer* buffer : allocation.assigned_buffers()) {
+ parts.push_back(buffer->ToString());
+ }
+ VLOG(3) << " " << tensorflow::str_util::Join(parts, ", ");
+ }
+
+ device_allocations.push_back(device_allocation);
+ // Since the output buffer and all the temporary buffers were written into
+ // by the JITed code, msan has no way of knowing their memory was
+ // initialized. Mark them initialized so that msan doesn't flag loads from
+ // these buffers.
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(),
+ allocation.size());
+ }
+
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelOutputAllocation());
+ BufferAllocation::Index result_index = result_allocation->index();
+ VLOG(3) << "result index: " << result_index;
+
+ // Allocate profiling counters for each hlo instruction that we would like to
+ // profile. Allocate an additional profile counter for the entire
+ // computation.
+ std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
+
+ std::vector<void*> buffer_pointers;
+ for (auto& device_allocation : device_allocations) {
+ buffer_pointers.push_back(device_allocation.opaque());
+ }
+
+ // Resolve functions for all the HLO instructions ahead of time.
+ std::map<HloInstruction*, ComputeFunctionType> functions;
+ for (auto& entry : *functions_names_) {
+ tensorflow::mutex_lock lock(jit_mutex_);
+ HloInstruction* instruction = entry.first;
+ llvm::JITSymbol sym = jit_->FindSymbol(entry.second);
+ TF_RET_CHECK(sym);
+ InsertOrDie(&functions, instruction,
+ reinterpret_cast<ComputeFunctionType>(sym.getAddress()));
+ }
+
+ // Map containing pointers to result buffers for each instruction.
+ std::map<HloInstruction*, const void*> results;
+
+ uint64 start_micros = tensorflow::Env::Default()->NowMicros();
+
+ std::list<HloInstruction*> pending;
+
+ // Call the function for each HLO instruction in topological order.
+ for (auto* instruction :
+ module().entry_computation()->MakeInstructionPostOrder()) {
+    // Parameters and constants have no functions associated with them.
+    // Instead, just copy the existing buffer into the map of instruction
+    // results.
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ InsertOrDie(&results, instruction,
+ arguments[instruction->parameter_number()].opaque());
+ } else if (instruction->opcode() == HloOpcode::kConstant) {
+ unsigned char* aligned_data =
+ FindOrDie(aligned_constants_, instruction).get();
+ InsertOrDie(&results, instruction, aligned_data);
+ } else {
+ TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall);
+ pending.push_back(instruction);
+ }
+ }
+
+ auto* temps_array = buffer_pointers.data();
+ auto* profile_counters_array = profile_counters.data();
+ auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool());
+ tensorflow::mutex completion_queue_lock;
+ tensorflow::condition_variable completion_queue_cv;
+ std::deque<HloInstruction*> completion_queue;
+ int64 instructions_in_flight = 0;
+ while (!pending.empty() || instructions_in_flight > 0) {
+ auto pending_it = pending.begin();
+ while (pending_it != pending.end()) {
+ HloInstruction* instruction = *pending_it;
+ // Skip pending instructions whose operands aren't ready.
+ if (std::any_of(instruction->operands().begin(),
+ instruction->operands().end(),
+ [&](HloInstruction* operand) {
+ return !ContainsKey(results, operand);
+ })) {
+ ++pending_it;
+ continue;
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelAllocation(instruction));
+
+ void* result_buffer = buffer_pointers[result_allocation->index()];
+      // We cannot use a move-only RAII type like std::unique_ptr because the
+      // list of operands is allocated on the main thread and transferred to
+      // the worker via the lambda passed to the thread pool's Schedule call.
+      // For the lambda to take ownership, we would need generalized lambda
+      // capture, which is a C++14 feature.
+ auto operand_buffers = new const void*[instruction->operand_count()];
+ std::transform(instruction->operands().begin(),
+ instruction->operands().end(), operand_buffers,
+ [&results](HloInstruction* operand) {
+ return FindOrDie(results, operand);
+ });
+ auto function = FindOrDie(functions, instruction);
+ // The thread pool entry takes ownership of |operand_buffers|.
+ thread_pool->Schedule([instruction, &completion_queue,
+ &completion_queue_lock, &completion_queue_cv,
+ result_buffer, run_options, operand_buffers,
+ temps_array, profile_counters_array, function] {
+ function(result_buffer, run_options, operand_buffers, temps_array,
+ profile_counters_array);
+ delete[] operand_buffers;
+        // Push the completed HLO instruction onto the queue; the main thread
+        // will pop it off and potentially launch more work that uses the
+        // result.
+ {
+ tensorflow::mutex_lock l(completion_queue_lock);
+ completion_queue.push_back(instruction);
+ completion_queue_cv.notify_all();
+ }
+ });
+
+ ++instructions_in_flight;
+ pending_it = pending.erase(pending_it);
+ }
+ // Wait for a completed HLO instruction to be present in the queue. We will
+ // pop it out of the queue and make the result available to its users.
+ HloInstruction* instruction;
+ do {
+ tensorflow::mutex_lock l(completion_queue_lock);
+ if (completion_queue.empty()) {
+ completion_queue_cv.wait(l);
+ }
+ if (!completion_queue.empty()) {
+ instruction = completion_queue.front();
+ completion_queue.pop_front();
+ break;
+ }
+ } while (1);
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* result_allocation,
+ assignment_->GetUniqueTopLevelAllocation(instruction));
+ void* result_buffer = buffer_pointers[result_allocation->index()];
+ InsertOrDie(&results, instruction, result_buffer);
+ --instructions_in_flight;
+ }
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ double nanoseconds = (end_micros - start_micros) * 1000.0;
+ execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
+ // The last profile counter is used for the computation as a whole.
+ execution_profile_.set_compute_cycle_count(profile_counters.back());
+ }
+ if (hlo_execution_profile != nullptr) {
+ hlo_execution_profile->set_total_cycles_executed(profile_counters.back());
+
+ for (auto hlo_prof_idx : hlo_to_profile_idx_) {
+ const HloInstruction* hlo = hlo_prof_idx.first;
+ uint64 cycles_taken = profile_counters[hlo_prof_idx.second];
+ hlo_execution_profile->AddProfileResult(hlo, cycles_taken);
+ }
+ }
+
+ // Mark the buffers that are actually live (used in the output) when the
+ // computation finishes executing.
+ std::unordered_set<const void*> marked_addresses;
+ MarkLiveAddressesInOutput(device_allocations[result_index].opaque(),
+ result_shape(), &marked_addresses);
+
+ VLOG(3) << "Live addresses in output marking found "
+ << marked_addresses.size() << " addresses:\n"
+ << tensorflow::str_util::Join(
+ marked_addresses, ", ", [](string* out, const void* address) {
+ tensorflow::strings::StrAppend(
+ out, tensorflow::strings::Printf("%p", address));
+ });
+
+ // Computation is done - deallocate temp buffers. Keep those marked
+ // live because they are referenced by the output of the computation
+ // and are needed by the service. They will be deallocated by the
+ // service.
+ for (auto i = 0; i < device_allocations.size(); ++i) {
+ auto alloc = device_allocations[i];
+ if (marked_addresses.count(alloc.opaque()) == 0 &&
+ alloc.opaque() != nullptr) {
+ VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " ["
+ << alloc.opaque() << "]";
+ TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc));
+ }
+ }
+
+ return device_allocations[result_index];
+}
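The dispatch loop above is a small dataflow scheduler: ready instructions are handed to the thread pool, and a mutex-protected completion queue feeds finished instructions back to the main thread, which can then launch their users. Below is a hedged, self-contained sketch of the same pattern, using plain std::thread and string node names as stand-ins for the TensorFlow thread pool and HloInstruction*:

#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <vector>

// Toy node: a name plus the names of its operands.
struct Node {
  std::string name;
  std::vector<std::string> operands;
};

int main() {
  // c depends on a and b; d depends on c.
  std::vector<Node> pending = {{"a", {}}, {"b", {}}, {"c", {"a", "b"}},
                               {"d", {"c"}}};
  std::set<std::string> done;
  std::mutex queue_mu;
  std::condition_variable queue_cv;
  std::deque<std::string> completion_queue;
  std::vector<std::thread> workers;
  int in_flight = 0;

  while (!pending.empty() || in_flight > 0) {
    // Launch every node whose operands are all done.
    for (auto it = pending.begin(); it != pending.end();) {
      bool ready = true;
      for (const auto& op : it->operands) ready = ready && done.count(op) > 0;
      if (!ready) { ++it; continue; }
      std::string name = it->name;
      workers.emplace_back([name, &queue_mu, &queue_cv, &completion_queue] {
        // ... the node's compute function would run here ...
        std::lock_guard<std::mutex> lock(queue_mu);
        completion_queue.push_back(name);
        queue_cv.notify_all();
      });
      ++in_flight;
      it = pending.erase(it);
    }
    // Wait for one completion and record it, releasing its users next round.
    std::unique_lock<std::mutex> lock(queue_mu);
    queue_cv.wait(lock, [&] { return !completion_queue.empty(); });
    done.insert(completion_queue.front());
    completion_queue.pop_front();
    --in_flight;
  }
  for (auto& t : workers) t.join();
  std::cout << "executed " << done.size() << " nodes\n";  // prints 4
  return 0;
}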
+
+StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ return Unimplemented(
+ "ParallelCpuExecutable not supported yet with LocalService execution");
+}
+
+Status ParallelCpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
+ return Unimplemented(
+ "preallocated result buffer not supported with ParallelCpuExecutable");
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+ParallelCpuExecutable::ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ // TODO(b/30671675): Implement asynchronous execution mode.
+ return Unimplemented(
+ "Asynchronous execution on stream is not yet supported on CPU.");
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
new file mode 100644
index 0000000000..51ec9e5a74
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -0,0 +1,124 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
+
+#include <stddef.h>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace cpu {
+
+// CPU-targeting parallel implementation of the XLA Executable interface.
+//
+// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
+// architecture, so JIT-ed code and host code share the same ABI.
+class ParallelCpuExecutable : public Executable {
+ public:
+ ParallelCpuExecutable(
+ std::unique_ptr<SimpleOrcJIT> jit,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ std::unique_ptr<std::map<HloInstruction*, string>> instruction_functions,
+ std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx,
+ std::unordered_map<const HloInstruction*,
+ std::unique_ptr<unsigned char[]>>
+ aligned_constants);
+ ~ParallelCpuExecutable() override {}
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ Status ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) override;
+
+ // This should be called after set_ir_module_string.
+ const string& ir_module_string() const { return ir_module_string_; }
+
+ void set_ir_module_string(const string& ir_module_string) {
+ ir_module_string_ = ir_module_string;
+ }
+
+ private:
+ // The JIT containing compiled modules.
+ tensorflow::mutex jit_mutex_;
+ std::unique_ptr<SimpleOrcJIT> jit_ GUARDED_BY(jit_mutex_);
+
+ // Buffer assignment for the buffers we need to allocate.
+ std::unique_ptr<BufferAssignment> assignment_;
+
+ // The LLVM IR, in string format, of the unoptimized module generated for this
+ // ParallelCpuExecutable. We save a string instead of an llvm::Module* because
+ // leaving llvm::Module* in a singleton can cause the heap checker to emit
+ // false positives.
+ string ir_module_string_;
+
+ // Map containing the JITted function names for each HLO instruction.
+ std::unique_ptr<std::map<HloInstruction*, string>> functions_names_;
+
+ // Maps HLOs to their index into the profile counter array.
+ const std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx_;
+
+ // Map from HLO Constant instructions to a pointer to their literal data.
+  // The data stored in the protocol buffer might be insufficiently aligned,
+  // so we create a sufficiently aligned copy and store it in this map.
+ std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
+ aligned_constants_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable);
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc
new file mode 100644
index 0000000000..c2f64eb27a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int64;
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32(
+ const void* run_options_ptr, float* out, float* lhs, float* rhs,
+ int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels,
+ int64 kernel_rows, int64 kernel_cols, int64 kernel_channels,
+ int64 kernel_filters, int64 output_rows, int64 output_cols,
+ int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
+ int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
+ int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+ tensorflow::xla::EigenConvF32Impl(
+ *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch,
+ input_rows, input_cols, input_channels, kernel_rows, kernel_cols,
+ kernel_channels, kernel_filters, output_rows, output_cols, row_stride,
+ col_stride, padding_top, padding_bottom, padding_left, padding_right,
+ lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h
new file mode 100644
index 0000000000..05ae094691
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+extern void __xla_cpu_runtime_EigenConvF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 input_batch,
+ tensorflow::int64 input_rows, tensorflow::int64 input_cols,
+ tensorflow::int64 input_channels, tensorflow::int64 kernel_rows,
+ tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels,
+ tensorflow::int64 kernel_filters, tensorflow::int64 output_rows,
+ tensorflow::int64 output_cols, tensorflow::int64 row_stride,
+ tensorflow::int64 col_stride, tensorflow::int64 padding_top,
+ tensorflow::int64 padding_bottom, tensorflow::int64 padding_left,
+ tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation,
+ tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
+ tensorflow::int64 rhs_col_dilation);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h
new file mode 100644
index 0000000000..02f45fee0f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/platform/types.h"
+
+// 'tensorflow' namespace is used so that int64 and other types don't require
+// qualification.
+namespace tensorflow {
+namespace xla {
+
+template <typename EigenDevice>
+void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs,
+ float* rhs, int64 input_batch, int64 input_rows,
+ int64 input_cols, int64 input_channels, int64 kernel_rows,
+ int64 kernel_cols, int64 kernel_channels,
+ int64 kernel_filters, int64 output_rows,
+ int64 output_cols, int64 row_stride, int64 col_stride,
+ int64 padding_top, int64 padding_bottom,
+ int64 padding_left, int64 padding_right,
+ int64 lhs_row_dilation, int64 lhs_col_dilation,
+ int64 rhs_row_dilation, int64 rhs_col_dilation) {
+ const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+ Eigen::Aligned>
+ input(lhs, input_batch, input_rows, input_cols, input_channels);
+
+ const Eigen::TensorMap<Eigen::Tensor<const float, 4, Eigen::RowMajor>,
+ Eigen::Aligned>
+ kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters);
+
+ Eigen::TensorMap<Eigen::Tensor<float, 4, Eigen::RowMajor>, Eigen::Aligned>
+ output(out, input_batch, output_rows, output_cols, kernel_filters);
+
+ Eigen::array<Eigen::IndexPair<int64>, 1> contract_dims;
+ contract_dims[0] = Eigen::IndexPair<int64>(1, 0);
+
+  // Molds the output of the patch extraction code into a 2d tensor:
+  // - the first dimension (dims[0]): the batch and output positions, i.e.
+  //   everything other than the patch values
+  // - the second dimension (dims[1]): the patch values to be contracted with
+  //   the kernel
+ Eigen::DSizes<int64, 2> pre_contract_dims;
+ pre_contract_dims[0] = output_cols * output_rows * input_batch;
+ pre_contract_dims[1] = kernel_channels * kernel_cols * kernel_rows;
+
+ // Molds the output of the contraction into the shape expected by the user:
+ Eigen::DSizes<int64, 4> post_contract_dims;
+ post_contract_dims[0] = input_batch;
+ post_contract_dims[1] = output_rows;
+ post_contract_dims[2] = output_cols;
+ post_contract_dims[3] = kernel_filters;
+
+ Eigen::DSizes<int64, 2> kernel_dims;
+ kernel_dims[0] = kernel_channels * kernel_cols * kernel_rows;
+ kernel_dims[1] = kernel_filters;
+
+ // The row and column dimensions must be flipped when passed to Eigen.
+ output.device(device) =
+ input
+ .extract_image_patches(kernel_cols, kernel_rows, col_stride,
+ row_stride, rhs_col_dilation, rhs_row_dilation,
+ lhs_col_dilation, lhs_row_dilation,
+ padding_left, padding_right, padding_top,
+ padding_bottom, 0.0f)
+ .reshape(pre_contract_dims)
+ .contract(kernel.reshape(kernel_dims), contract_dims)
+ .reshape(post_contract_dims);
+}
+
+} // namespace xla
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_IMPL_H_
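As a concrete check of the reshape bookkeeping above, a hypothetical 1x5x5x3 input convolved with eight 3x3x3 kernels at stride 1 and no padding reshapes the patches to [9, 27], the kernel to [27, 8], and the contraction result back to the NHWC output [1, 3, 3, 8]; the short sketch below just prints those sizes:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical sizes: one 5x5x3 input, eight 3x3x3 kernels, stride 1, VALID.
  int64_t input_batch = 1, input_rows = 5, input_cols = 5;
  int64_t kernel_rows = 3, kernel_cols = 3, kernel_channels = 3,
          kernel_filters = 8;
  int64_t output_rows = input_rows - kernel_rows + 1;  // 3
  int64_t output_cols = input_cols - kernel_cols + 1;  // 3
  // Mirrors pre_contract_dims / kernel_dims / post_contract_dims above.
  std::cout << "patches: [" << output_cols * output_rows * input_batch << ", "
            << kernel_channels * kernel_cols * kernel_rows << "]\n";   // [9, 27]
  std::cout << "kernel:  [" << kernel_channels * kernel_cols * kernel_rows
            << ", " << kernel_filters << "]\n";                        // [27, 8]
  std::cout << "output:  [" << input_batch << ", " << output_rows << ", "
            << output_cols << ", " << kernel_filters << "]\n";  // [1, 3, 3, 8]
  return 0;
}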
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
new file mode 100644
index 0000000000..677080a862
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.cc
@@ -0,0 +1,81 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+
+template <typename T>
+void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ const xla::ExecutableRunOptions* run_options =
+ static_cast<const xla::ExecutableRunOptions*>(run_options_ptr);
+
+ int64 lhs_rows = m;
+ int64 lhs_cols = k;
+ if (transpose_lhs) {
+ std::swap(lhs_rows, lhs_cols);
+ }
+
+ int64 rhs_rows = k;
+ int64 rhs_cols = n;
+ if (transpose_rhs) {
+ std::swap(rhs_rows, rhs_cols);
+ }
+
+ const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
+ lhs, lhs_rows, lhs_cols);
+ const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
+ rhs, rhs_rows, rhs_cols);
+ Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+
+ typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
+ int lhs_contract_dim = transpose_lhs ? 0 : 1;
+ int rhs_contract_dim = transpose_rhs ? 1 : 0;
+ const Eigen::array<DimPair, 1> dims(
+ DimPair(lhs_contract_dim, rhs_contract_dim));
+
+  // Matrix multiply is a special case of the "contract" operation where the
+  // contraction is performed along dimension 1 of the lhs and dimension 0 of
+  // the rhs (swapped for an operand whose transpose flag is set).
+ C.device(*run_options->intra_op_thread_pool()) = A.contract(B, dims);
+}
+
+} // namespace
+
+void __xla_cpu_runtime_EigenMatMulF32(const void* run_options_ptr, float* out,
+ float* lhs, float* rhs, int64 m, int64 n,
+ int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+}
+
+void __xla_cpu_runtime_EigenMatMulF64(const void* run_options_ptr, double* out,
+ double* lhs, double* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs,
+ int32 transpose_rhs) {
+ MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h
new file mode 100644
index 0000000000..fdb644651d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_matmul.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Performs a multi-threaded matrix multiplication using Eigen. 'lhs' and 'rhs'
+// are pointers to buffers containing input matrices in column-major
+// order. 'out' is a pointer to a buffer sufficiently large to hold the result
+// of the operation. Following standard nomenclature: lhs is m x k,
+// rhs is k x n, and out is m x n.
+extern void __xla_cpu_runtime_EigenMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+extern void __xla_cpu_runtime_EigenMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATMUL_H_
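A small illustration of the column-major layout the declarations above expect: element (row, col) of a matrix with `rows` rows lives at data[row + col * rows]. The helper and sample matrix here are hypothetical:

#include <cstdint>
#include <cstdio>

// Column-major indexing helper matching the layout the entry points expect.
inline int64_t ColMajorIndex(int64_t row, int64_t col, int64_t rows) {
  return row + col * rows;
}

int main() {
  // Hypothetical 2x3 matrix [[1, 2, 3], [4, 5, 6]] stored column by column.
  float a[6] = {1, 4, 2, 5, 3, 6};
  // a(1, 2) == 6.
  std::printf("a(1,2) = %g\n", a[ColMajorIndex(1, 2, /*rows=*/2)]);
  return 0;
}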
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc
new file mode 100644
index 0000000000..d0b0e11ac0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int64;
+
+TF_ATTRIBUTE_NO_SANITIZE_MEMORY void
+__xla_cpu_runtime_EigenSingleThreadedConvF32(
+ const void* run_options_ptr, float* out, float* lhs, float* rhs,
+ int64 input_batch, int64 input_rows, int64 input_cols, int64 input_channels,
+ int64 kernel_rows, int64 kernel_cols, int64 kernel_channels,
+ int64 kernel_filters, int64 output_rows, int64 output_cols,
+ int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom,
+ int64 padding_left, int64 padding_right, int64 lhs_row_dilation,
+ int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) {
+ tensorflow::xla::EigenConvF32Impl(
+ Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows,
+ input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels,
+ kernel_filters, output_rows, output_cols, row_stride, col_stride,
+ padding_top, padding_bottom, padding_left, padding_right,
+ lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h
new file mode 100644
index 0000000000..8ae1a42149
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+extern void __xla_cpu_runtime_EigenSingleThreadedConvF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 input_batch,
+ tensorflow::int64 input_rows, tensorflow::int64 input_cols,
+ tensorflow::int64 input_channels, tensorflow::int64 kernel_rows,
+ tensorflow::int64 kernel_cols, tensorflow::int64 kernel_channels,
+ tensorflow::int64 kernel_filters, tensorflow::int64 output_rows,
+ tensorflow::int64 output_cols, tensorflow::int64 row_stride,
+ tensorflow::int64 col_stride, tensorflow::int64 padding_top,
+ tensorflow::int64 padding_bottom, tensorflow::int64 padding_left,
+ tensorflow::int64 padding_right, tensorflow::int64 lhs_row_dilation,
+ tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation,
+ tensorflow::int64 rhs_col_dilation);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_CONV2D_H_
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
new file mode 100644
index 0000000000..384a978873
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::int32;
+using tensorflow::int64;
+
+namespace {
+
+template <typename T>
+void MatMul(const void* run_options_ptr, T* out, T* lhs, T* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ int64 lhs_rows = m;
+ int64 lhs_cols = k;
+ if (transpose_lhs) {
+ std::swap(lhs_rows, lhs_cols);
+ }
+
+ int64 rhs_rows = k;
+ int64 rhs_cols = n;
+ if (transpose_rhs) {
+ std::swap(rhs_rows, rhs_cols);
+ }
+
+ const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> A(
+ lhs, lhs_rows, lhs_cols);
+ const Eigen::TensorMap<Eigen::Tensor<const T, 2>, Eigen::Aligned> B(
+ rhs, rhs_rows, rhs_cols);
+ Eigen::TensorMap<Eigen::Tensor<T, 2>, Eigen::Aligned> C(out, m, n);
+
+ typedef typename Eigen::Tensor<T, 2>::DimensionPair DimPair;
+ int lhs_contract_dim = transpose_lhs ? 0 : 1;
+ int rhs_contract_dim = transpose_rhs ? 1 : 0;
+ const Eigen::array<DimPair, 1> dims(
+ DimPair(lhs_contract_dim, rhs_contract_dim));
+
+  // Matrix multiply is a special case of the "contract" operation where the
+  // contraction is performed along dimension 1 of the lhs and dimension 0 of
+  // the rhs (swapped for an operand whose transpose flag is set).
+ C = A.contract(B, dims);
+}
+
+} // namespace
+
+void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+ const void* run_options_ptr, float* out, float* lhs, float* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ MatMul<float>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+}
+
+void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
+ const void* run_options_ptr, double* out, double* lhs, double* rhs, int64 m,
+ int64 n, int64 k, int32 transpose_lhs, int32 transpose_rhs) {
+ MatMul<double>(run_options_ptr, out, lhs, rhs, m, n, k, transpose_lhs,
+ transpose_rhs);
+}
diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h
new file mode 100644
index 0000000000..029eb95142
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+extern "C" {
+
+// Performs a single-threaded matrix multiplication using Eigen. 'lhs' and 'rhs'
+// are pointers to buffers containing input matrices in column-major order.
+// 'out' is a pointer to a buffer sufficiently large to hold the result of the
+// operation. Following standard nomenclature: lhs is m x k, rhs is k x n, and
+// out is m x n.
+extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out,
+ float* lhs, float* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+extern void __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
+ const void* /* xla::ExecutableRunOptions* */ run_options_ptr, double* out,
+ double* lhs, double* rhs, tensorflow::int64 m, tensorflow::int64 n,
+ tensorflow::int64 k, tensorflow::int32 transpose_lhs,
+ tensorflow::int32 transpose_rhs);
+
+} // extern "C"
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_SINGLE_THREADED_MATMUL_H_
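A minimal caller sketch for the single-threaded entry point above; it assumes linking against this runtime and, since the implementation shown here never dereferences run_options_ptr, that passing nullptr for it is acceptable:

#include <cstdio>

#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"

int main() {
  // Column-major 2x3 lhs and 3x2 rhs; out is the 2x2 product.
  float lhs[6] = {1, 4, 2, 5, 3, 6};    // [[1, 2, 3], [4, 5, 6]]
  float rhs[6] = {7, 9, 11, 8, 10, 12};  // [[7, 8], [9, 10], [11, 12]]
  float out[4] = {0};
  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
      /*run_options_ptr=*/nullptr, out, lhs, rhs, /*m=*/2, /*n=*/2, /*k=*/3,
      /*transpose_lhs=*/0, /*transpose_rhs=*/0);
  // Expected column-major result [[58, 64], [139, 154]] -> {58, 139, 64, 154}.
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);
  return 0;
}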
diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
new file mode 100644
index 0000000000..c938a03df7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ xla::LocalClient* client(xla::ClientLibrary::LocalClientOrDie());
+
+ // Transfer parameters.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::LiteralUtil::CreateR2<float>(
+ {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}});
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ // Build computation.
+ xla::ComputationBuilder builder(client, "");
+ auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto add = builder.Add(p1, p0, {0});
+
+ xla::StatusOr<xla::Computation> computation_status = builder.Build();
+ xla::Computation computation = computation_status.ConsumeValueOrDie();
+
+ // Execute and transfer result of computation.
+ xla::ExecutionProfile profile;
+ xla::StatusOr<std::unique_ptr<xla::Literal>> result =
+ client->ExecuteAndTransfer(
+ computation,
+ /*arguments=*/{param0_data.get(), param1_data.get()},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/&profile);
+ std::unique_ptr<xla::Literal> actual = result.ConsumeValueOrDie();
+
+ LOG(INFO) << tensorflow::strings::Printf("computation took %lldns",
+ profile.compute_time_ns());
+ LOG(INFO) << xla::LiteralUtil::ToString(*actual);
+
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
new file mode 100644
index 0000000000..7754c556a8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
+
+#include <dlfcn.h>
+#include <stdint.h>
+#include <algorithm>
+#include <list>
+#include <utility>
+
+#include "external/llvm/include/llvm/IR/Mangler.h"
+#include "external/llvm/include/llvm/Support/CodeGen.h"
+#include "external/llvm/include/llvm/Support/Host.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/cpu/compiler_functor.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
+#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace cpu {
+namespace {
+
+// Converts a symbol 'name' into the form expected by dlsym().
+std::string CanonicalizeSymbol(const std::string &name) {
+#if defined(__APPLE__)
+ // On Mac OS X, dlsym() expects names not to be prefixed with a leading
+ // underscore.
+ if (!name.empty() && name.front() == '_') {
+ return name.substr(1);
+ }
+#endif
+ return name;
+}
+
+// A simple SymbolResolver that resolves the well-known XLA CPU runtime
+// symbols to their in-process addresses and otherwise delegates to the host
+// dynamic linker.
+struct SimpleResolver : public llvm::JITSymbolResolver {
+ llvm::JITSymbol findSymbol(const std::string &name) override {
+ void *func_addr = nullptr;
+
+ std::string canonical_name = CanonicalizeSymbol(name);
+ if (canonical_name == runtime::kEigenMatmulF32SymbolName) {
+ func_addr = reinterpret_cast<void *>(__xla_cpu_runtime_EigenMatMulF32);
+ } else if (canonical_name ==
+ runtime::kEigenSingleThreadedMatmulF32SymbolName) {
+ func_addr = reinterpret_cast<void *>(
+ __xla_cpu_runtime_EigenSingleThreadedMatMulF32);
+ } else if (canonical_name == runtime::kEigenConvF32SymbolName) {
+ func_addr = reinterpret_cast<void *>(__xla_cpu_runtime_EigenConvF32);
+ } else if (canonical_name ==
+ runtime::kEigenSingleThreadedConvF32SymbolName) {
+ func_addr = reinterpret_cast<void *>(
+ __xla_cpu_runtime_EigenSingleThreadedConvF32);
+ } else if (canonical_name ==
+ runtime::kAcquireInfeedBufferForDequeueSymbolName) {
+ func_addr = reinterpret_cast<void *>(
+ __xla_cpu_runtime_AcquireInfeedBufferForDequeue);
+ } else if (canonical_name ==
+ runtime::kReleaseInfeedBufferAfterDequeueSymbolName) {
+ func_addr = reinterpret_cast<void *>(
+ __xla_cpu_runtime_ReleaseInfeedBufferAfterDequeue);
+ } else if (canonical_name == runtime::kExpV4F32) {
+ func_addr = reinterpret_cast<void *>(runtime::ExpV4F32);
+ } else if (canonical_name == runtime::kExpV8F32) {
+ func_addr = reinterpret_cast<void *>(runtime::ExpV8F32);
+ } else if (canonical_name == runtime::kLogV4F32) {
+ func_addr = reinterpret_cast<void *>(runtime::LogV4F32);
+ } else if (canonical_name == runtime::kLogV8F32) {
+ func_addr = reinterpret_cast<void *>(runtime::LogV8F32);
+ } else if (canonical_name == runtime::kTanhV4F32) {
+ func_addr = reinterpret_cast<void *>(runtime::TanhV4F32);
+ } else if (canonical_name == runtime::kTanhV8F32) {
+ func_addr = reinterpret_cast<void *>(runtime::TanhV8F32);
+ } else {
+ func_addr = dlsym(RTLD_DEFAULT, canonical_name.c_str());
+ }
+
+ if (func_addr == nullptr) {
+ return nullptr;
+ }
+ llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
+ llvm::JITSymbolFlags::None);
+ return symbol_info;
+ }
+ llvm::JITSymbol findSymbolInLogicalDylib(const std::string &name) override {
+ return nullptr;
+ }
+};
+
+llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
+ llvm::SmallVector<std::string, 0> result;
+ llvm::StringMap<bool> host_features;
+ if (llvm::sys::getHostCPUFeatures(host_features)) {
+ for (auto &feature : host_features) {
+ if (feature.second) {
+ result.push_back(feature.first());
+ }
+ }
+ }
+ return result;
+}
+
+CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
+ CompilerFunctor::VectorIntrinsics intrinsics;
+ intrinsics.sse_intrinsics = (&runtime::ExpV4F32 != nullptr);
+ intrinsics.avx_intrinsics = (&runtime::ExpV8F32 != nullptr);
+ return intrinsics;
+}
+
+} // namespace
+
+SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions &target_options,
+ llvm::CodeGenOpt::Level opt_level)
+ : target_machine_(
+ CHECK_NOTNULL(llvm::EngineBuilder()
+ .setTargetOptions(target_options)
+ .setOptLevel(opt_level)
+ .selectTarget(
+ /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
+ /*MCPU=*/llvm::sys::getHostCPUName(),
+ /*MAttrs=*/DetectMachineAttributes()))),
+ disassembler_(*target_machine_),
+ data_layout_(target_machine_->createDataLayout()),
+ compile_layer_(object_layer_,
+ CompilerFunctor(target_machine_.get(), &disassembler_,
+ opt_level, GetAvailableIntrinsics())) {}
+
+SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule(
+ std::unique_ptr<llvm::Module> module) {
+ // The Orc API adds a whole iterable "set" of modules, so we wrap the module
+ // in a vector.
+ std::vector<std::unique_ptr<llvm::Module>> module_set;
+ module_set.push_back(std::move(module));
+ auto handle = compile_layer_.addModuleSet(
+ std::move(module_set), MakeUnique<llvm::SectionMemoryManager>(),
+ MakeUnique<SimpleResolver>());
+ module_handles_.push_back(handle);
+ return handle;
+}
+
+void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) {
+ module_handles_.erase(
+ std::remove(module_handles_.begin(), module_handles_.end(), handle));
+ compile_layer_.removeModuleSet(handle);
+}
+
+llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string &name) {
+ std::string mangled_name;
+ {
+ llvm::raw_string_ostream mangled_name_stream(mangled_name);
+ llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, data_layout_);
+ }
+
+  // Resolve the symbol from the last module to the first, allowing later
+  // redefinitions of a symbol to shadow earlier ones.
+ for (auto &handle :
+ llvm::make_range(module_handles_.rbegin(), module_handles_.rend())) {
+ if (auto symbol =
+ compile_layer_.findSymbolIn(handle, mangled_name,
+ /*ExportedSymbolsOnly=*/true)) {
+ return symbol;
+ }
+ }
+
+ return nullptr;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
new file mode 100644
index 0000000000..9d1c842e0f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h
@@ -0,0 +1,88 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "external/llvm/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h"
+#include "external/llvm/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace cpu {
+
+// Simplified LLVM JIT based on the new Orc API.
+//
+// This class wraps Orc's functionality into a single interface that only
+// exposes what we need for XLA.
+//
+// Supports JIT-ing multiple modules but without cross-module linking.
+// Implements eager compilation - the module is lowered to binary as soon as
+// it's added to the JIT.
+class SimpleOrcJIT {
+ public:
+ using ObjLayerT = llvm::orc::ObjectLinkingLayer<>;
+ using CompileLayerT = llvm::orc::IRCompileLayer<ObjLayerT>;
+ using ModuleHandleT = CompileLayerT::ModuleSetHandleT;
+
+ // Create a new JIT, targeting the host architecture.
+ // The |target_options| parameter allows customization of certain code
+  // generation properties of the TargetMachine (e.g. whether or not
+  // floating-point math can be reassociated).
+ // The |opt_level| parameter controls the optimization level of the code
+ // generator.
+ SimpleOrcJIT(const llvm::TargetOptions& target_options,
+ llvm::CodeGenOpt::Level opt_level);
+
+ // Data layout this JIT was created with.
+ const llvm::DataLayout& data_layout() const { return data_layout_; }
+
+ // Target triple (host) this JIT was created with.
+ const llvm::Triple& target_triple() const {
+ return target_machine_->getTargetTriple();
+ }
+
+ // Add a module to the JIT. Returns an opaque handle that can be used to later
+ // remove this module.
+ ModuleHandleT AddModule(std::unique_ptr<llvm::Module> module);
+
+ // Remove a module from the JIT and free the memory associated with it.
+ void RemoveModule(ModuleHandleT handle);
+
+  // Get the runtime address of the compiled symbol whose name is given.
+  // Returns a null JITSymbol if the symbol cannot be found.
+ llvm::JITSymbol FindSymbol(const std::string& name);
+
+ private:
+ std::vector<ModuleHandleT> module_handles_;
+ std::unique_ptr<llvm::TargetMachine> target_machine_;
+ const Disassembler disassembler_;
+ const llvm::DataLayout data_layout_;
+ ObjLayerT object_layer_;
+ CompileLayerT compile_layer_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SIMPLE_ORC_JIT_H_
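A hedged usage sketch for SimpleOrcJIT as declared above; it assumes this snapshot's LLVM Orc API and include layout, mirrors the symbol-address cast done in parallel_cpu_executable.cc, and the add1 function built here is purely illustrative:

#include <cstdio>
#include <memory>

#include "external/llvm/include/llvm/IR/IRBuilder.h"
#include "external/llvm/include/llvm/IR/LLVMContext.h"
#include "external/llvm/include/llvm/IR/Module.h"
#include "external/llvm/include/llvm/Support/TargetSelect.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"

int main() {
  // The JIT emits native code, so the native target must be initialized.
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  llvm::LLVMContext context;
  std::unique_ptr<llvm::Module> module(new llvm::Module("sketch", context));

  // Build: i32 add1(i32 x) { return x + 1; }  (illustrative IR only).
  llvm::IRBuilder<> b(context);
  llvm::FunctionType* fn_ty =
      llvm::FunctionType::get(b.getInt32Ty(), {b.getInt32Ty()}, false);
  llvm::Function* fn = llvm::Function::Create(
      fn_ty, llvm::Function::ExternalLinkage, "add1", module.get());
  b.SetInsertPoint(llvm::BasicBlock::Create(context, "entry", fn));
  b.CreateRet(b.CreateAdd(&*fn->arg_begin(), b.getInt32(1)));

  // Eager compilation: the module is lowered as soon as it is added.
  xla::cpu::SimpleOrcJIT jit(llvm::TargetOptions(), llvm::CodeGenOpt::Default);
  jit.AddModule(std::move(module));

  // FindSymbol applies the data-layout name mangling itself; cast the address
  // to a function pointer, as ParallelCpuExecutable does for its functions.
  llvm::JITSymbol sym = jit.FindSymbol("add1");
  auto add1 = reinterpret_cast<int (*)(int)>(sym.getAddress());
  std::printf("add1(41) = %d\n", add1(41));  // 42
  return 0;
}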
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
new file mode 100644
index 0000000000..423ec29fdc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu_transfer_manager.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
+#include "tensorflow/compiler/xla/service/cpu/infeed_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+namespace {
+
+class CpuInfeedBuffer : public cpu::runtime::InfeedBuffer {
+ public:
+ explicit CpuInfeedBuffer(int32 length)
+ : length_(length),
+ buffer_(new char[length]),
+ device_memory_(buffer_, length_) {}
+ ~CpuInfeedBuffer() override { delete[] buffer_; }
+
+ int32 length() override { return length_; }
+ void* data() override { return buffer_; }
+ void Done() override { delete this; }
+
+ se::DeviceMemoryBase* device_memory() { return &device_memory_; }
+
+ private:
+ int32 length_;
+ char* buffer_;
+ se::DeviceMemoryBase device_memory_;
+};
+
+} // namespace
+
+CpuTransferManager::CpuTransferManager()
+ : GenericTransferManager(se::host::kHostPlatformId) {}
+
+Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
+ const Literal& literal) {
+ const Shape& shape = literal.shape();
+ VLOG(2) << "transferring literal shape to infeed: "
+ << ShapeUtil::HumanString(shape);
+
+ // TODO(b/31381668) handle tuples.
+ if (ShapeUtil::IsTuple(shape)) {
+ return Unimplemented("Infeed with a tuple shape is not supported: %s",
+ ShapeUtil::HumanString(literal.shape()).c_str());
+ }
+
+ cpu::runtime::InfeedManager* infeed_manager =
+ cpu::runtime::GetInfeedManager();
+
+ int64 size = GetByteSizeRequirement(shape);
+ if (size > std::numeric_limits<int32>::max()) {
+ return Unimplemented("Infeed shape is too large: %s needs %lld bytes",
+ ShapeUtil::HumanString(literal.shape()).c_str(), size);
+ }
+ int32 size_32 = static_cast<int32>(size);
+ CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32);
+ TF_RETURN_IF_ERROR(TransferBufferToDevice(
+ executor, /*size=*/size, /*source=*/LiteralUtil::InternalData(literal),
+ queued_buffer->device_memory()));
+
+ infeed_manager->EnqueueBuffer(queued_buffer);
+
+ return Status::OK();
+}
+
+} // namespace xla
+
+static xla::TransferManager* CreateCpuTransferManager() {
+ return new xla::CpuTransferManager();
+}
+
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(se::host::kHostPlatformId,
+ &CreateCpuTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h
new file mode 100644
index 0000000000..727462252d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// An implementation of the XLA GenericTransferManager that
+// handles CPU-specific infeed.
+class CpuTransferManager : public GenericTransferManager {
+ public:
+ CpuTransferManager();
+ ~CpuTransferManager() override {}
+
+ Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
+ const Literal& literal) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc
new file mode 100644
index 0000000000..1bef4e2b8c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator(
+ perftools::gputools::Platform* platform,
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ stream_executors)
+ : DeviceMemoryAllocator(platform),
+ stream_executors_(stream_executors.begin(), stream_executors.end()) {}
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+StreamExecutorMemoryAllocator::Allocate(int device_ordinal, uint64 size,
+ bool retry_on_failure) {
+ if (size == 0) {
+ return perftools::gputools::DeviceMemoryBase(nullptr, 0);
+ }
+ TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor,
+ GetStreamExecutor(device_ordinal));
+ return stream_executor->AllocateArray<uint8>(size);
+}
+
+tensorflow::Status StreamExecutorMemoryAllocator::Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) {
+ if (!mem->is_null()) {
+ TF_ASSIGN_OR_RETURN(perftools::gputools::StreamExecutor * stream_executor,
+ GetStreamExecutor(device_ordinal));
+ // We make a local copy of 'mem' so the original is not zeroed out by the
+ // Deallocate() call below. This gives us a better chance of
+ // catching double-free bugs, since Deallocate silently succeeds for null
+ // values.
+ perftools::gputools::DeviceMemoryBase mem_copy(*mem);
+ stream_executor->Deallocate(&mem_copy);
+ }
+ return tensorflow::Status::OK();
+}
+
+StatusOr<perftools::gputools::StreamExecutor*>
+StreamExecutorMemoryAllocator::GetStreamExecutor(int device_ordinal) {
+ if (device_ordinal < 0) {
+ return InvalidArgument("device ordinal value (%d) must be non-negative",
+ device_ordinal);
+ }
+ if (device_ordinal >= stream_executors_.size()) {
+ return InvalidArgument(
+ "device ordinal value (%d) >= number of devices (%zu)", device_ordinal,
+ stream_executors_.size());
+ }
+ if (stream_executors_[device_ordinal] == nullptr) {
+ return NotFound("Device %s:%d present but not supported",
+ platform()->Name().c_str(), device_ordinal);
+ }
+ return stream_executors_[device_ordinal];
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h
new file mode 100644
index 0000000000..461cc818bf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/device_memory_allocator.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Interface for device memory allocators used within the XLA service. An
+// allocator is responsible for allocating memory on all devices of a particular
+// platform.
+class DeviceMemoryAllocator {
+ public:
+ // Parameter platform indicates which platform the allocator allocates memory
+ // on. Must be non-null.
+ explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform)
+ : platform_(platform) {}
+ virtual ~DeviceMemoryAllocator() {}
+
+ // 'retry_on_failure': If false, and the first attempt to allocate the memory
+ // fails, the allocation should return immediately without retrying.
+  // An example use case is optional scratch spaces, where a failure has only a
+  // performance impact.
+ // Allocate() should return a null pointer for a size-0 allocation.
+ // Deallocate() must be a no-op for null pointers.
+ virtual StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure = true) = 0;
+ virtual tensorflow::Status Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0;
+
+ // Return the platform that the allocator allocates memory on.
+ const perftools::gputools::Platform* platform() const { return platform_; }
+
+ protected:
+ const perftools::gputools::Platform* platform_;
+};
+
+// Default memory allocator for a platform which uses
+// StreamExecutor::Allocate/Deallocate.
+class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator {
+ public:
+ StreamExecutorMemoryAllocator(
+ perftools::gputools::Platform* platform,
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ stream_executors);
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure = true) override;
+ tensorflow::Status Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override;
+
+ private:
+ StatusOr<perftools::gputools::StreamExecutor*> GetStreamExecutor(
+ int device_ordinal);
+
+ // A vector indexed by device ordinal of StreamExecutors for each device of
+ // the allocator's platform type. If an element is nullptr, then the device
+ // with the respective device ordinal is not supported by XLA.
+ std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEVICE_MEMORY_ALLOCATOR_H_
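
A rough usage sketch of the contract documented in this header (size-0 allocations return a null DeviceMemoryBase, Deallocate() is a no-op for null pointers, and retry_on_failure=false suits optional scratch space). This is illustrative only and not part of the change; the helper name AllocateOptionalScratch is hypothetical.

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace se = ::perftools::gputools;

// Hypothetical helper: grab optional scratch space on one device, use it, and
// release it, relying on the documented allocator contract.
tensorflow::Status AllocateOptionalScratch(xla::DeviceMemoryAllocator* allocator,
                                           int device_ordinal,
                                           tensorflow::uint64 bytes) {
  // retry_on_failure=false: a failed scratch allocation only costs performance.
  TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase scratch,
                      allocator->Allocate(device_ordinal, bytes,
                                          /*retry_on_failure=*/false));
  // ... use `scratch` here ...
  // Deallocate() is a no-op for a null base (e.g. from a size-0 allocation).
  return allocator->Deallocate(device_ordinal, &scratch);
}
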
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
new file mode 100644
index 0000000000..5b29686100
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo,
+ HloOpcode opcode,
+ HloInstruction* operand) {
+ return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s",
+ HloOpcodeString(opcode).c_str());
+}
+
+Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo,
+ HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s",
+ HloOpcodeString(opcode).c_str());
+}
+
+void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) {
+ VLOG(3) << "marking HLO " << &instruction << " as visiting: ";
+ CHECK(NotVisited(instruction));
+ visit_state_[&instruction] = VisitState::kVisiting;
+}
+
+void DfsHloVisitor::SetVisited(const HloInstruction& instruction) {
+ VLOG(3) << "marking HLO " << &instruction << " as visited: ";
+ CHECK(NotVisited(instruction) || IsVisiting(instruction));
+ visit_state_[&instruction] = VisitState::kVisited;
+}
+
+bool DfsHloVisitor::IsVisiting(const HloInstruction& instruction) {
+ if (visit_state_.count(&instruction) == 0) {
+ return false;
+ }
+ return visit_state_[&instruction] == VisitState::kVisiting;
+}
+
+bool DfsHloVisitor::DidVisit(const HloInstruction& instruction) {
+ if (visit_state_.count(&instruction) == 0) {
+ return false;
+ }
+ return visit_state_[&instruction] == VisitState::kVisited;
+}
+
+bool DfsHloVisitor::NotVisited(const HloInstruction& instruction) {
+ return visit_state_.count(&instruction) == 0 ||
+ visit_state_[&instruction] == VisitState::kNotVisited;
+}
+
+Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); }
+
+Status DfsHloVisitor::Postprocess(HloInstruction* visited) {
+ return Status::OK();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
new file mode 100644
index 0000000000..24fb6dee84
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -0,0 +1,289 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class HloComputation;
+class HloInstruction;
+
+// A postorder depth-first HloInstruction visitor. By the time Handle* is called
+// on an instruction, all of its operands have already been visited. User code
+// can subclass this to iterate over an HloInstruction DAG. The Handle* routines
+// have operands / data unpacked for ease of use in the visitor subclass.
+//
+// No instruction will ever be visited twice; however, the root instruction will
+// be reported again when the traversal is done via a call to FinishVisit.
+//
+// A subclass must override at least
+// (either HandleElementwiseUnary or all the Handle methods for unary ops) and
+// (either HandleElementwiseBinary or all the Handle methods for binary ops).
+// The default Handle methods for (unary, binary) ops call
+// (HandleElementwiseUnary, HandleElementwiseBinary).
+// The default (HandleElementwiseUnary, HandleElementwiseBinary) return an
+// "unimplemented" error status.
+//
+// Note: this may change to an iterator in the future for flexibility purposes.
+//
+// TODO(b/26548304): Stop passing in information about the visited
+// instruction that is accessible from the instruction object itself.
+class DfsHloVisitor {
+ public:
+ DfsHloVisitor()
+ : visit_state_(32) // Start the hash table a bit larger to avoid resizes
+ {}
+ virtual ~DfsHloVisitor() {}
+
+  // These routines are self-descriptive; see the class comment for usage
+  // information.
+
+ virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* operand);
+ virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
+ virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
+ HloInstruction* arg, HloInstruction* max) = 0;
+ virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) = 0;
+ virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs);
+ }
+ virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs);
+ }
+ virtual Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
+ virtual Status HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) {
+ return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand);
+ }
+ virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) {
+ return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand);
+ }
+ virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs);
+ }
+ virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) = 0;
+ virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs);
+ }
+ virtual Status HandleConvolution(HloInstruction* convolution,
+ HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window) = 0;
+ virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0;
+ virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ return HandleElementwiseBinary(compare, opcode, lhs, rhs);
+ }
+ virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs);
+ }
+ virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs);
+ }
+ virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs);
+ }
+ virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs);
+ }
+ virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) {
+ return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand);
+ }
+ virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) {
+ return HandleElementwiseUnary(sign, HloOpcode::kSign, operand);
+ }
+ virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) {
+ return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand);
+ }
+ virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) {
+ return HandleElementwiseUnary(exp, HloOpcode::kExp, operand);
+ }
+ virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) {
+ return HandleElementwiseUnary(floor, HloOpcode::kFloor, operand);
+ }
+ virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) {
+ return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand);
+ }
+ virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) {
+ return HandleElementwiseUnary(log, HloOpcode::kLog, operand);
+ }
+ virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) {
+ return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand);
+ }
+ virtual Status HandleLogicalAnd(HloInstruction* logical_and,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs,
+ rhs);
+ }
+ virtual Status HandleLogicalNot(HloInstruction* logical_not,
+ HloInstruction* operand) {
+ return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand);
+ }
+ virtual Status HandleLogicalOr(HloInstruction* logical_or,
+ HloInstruction* lhs, HloInstruction* rhs) {
+ return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs);
+ }
+
+ virtual Status HandleInfeed(HloInstruction* infeed) = 0;
+ virtual Status HandleRng(HloInstruction* random,
+ RandomDistribution distribution) = 0;
+ virtual Status HandleReverse(HloInstruction* reverse,
+ HloInstruction* operand) = 0;
+ virtual Status HandleSort(HloInstruction* sort, HloInstruction* operand) = 0;
+ virtual Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) = 0;
+ virtual Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) = 0;
+ virtual Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) = 0;
+ virtual Status HandleBitcast(HloInstruction* bitcast) = 0;
+ virtual Status HandleBroadcast(HloInstruction* broadcast) = 0;
+ virtual Status HandleReshape(HloInstruction* reshape) = 0;
+ virtual Status HandleTranspose(HloInstruction* transpose) = 0;
+ virtual Status HandleParameter(HloInstruction* parameter) = 0;
+ virtual Status HandleFusion(HloInstruction* fusion) = 0;
+ virtual Status HandleCall(
+ HloInstruction* call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) = 0;
+ virtual Status HandleCustomCall(
+ HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) = 0;
+ virtual Status HandleSlice(HloInstruction* slice,
+ HloInstruction* operand) = 0;
+ virtual Status HandleDynamicSlice(
+ HloInstruction* slice,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
+ virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+ HloInstruction* operand,
+ HloInstruction* update,
+ HloInstruction* start_indices) = 0;
+ virtual Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) = 0;
+ virtual Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) = 0;
+ virtual Status HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* operand,
+ const Window& window,
+ HloComputation* function) = 0;
+ virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0;
+ virtual Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+ HloComputation* condition,
+ HloComputation* body) = 0;
+
+ virtual Status HandlePad(HloInstruction* pad) = 0;
+
+ virtual Status HandleSend(HloInstruction* send) = 0;
+
+ virtual Status HandleRecv(HloInstruction* recv) = 0;
+
+ // Invoked to inform the visitor that the traversal has completed, and that
+ // the root was "root".
+ virtual Status FinishVisit(HloInstruction* root) = 0;
+
+  // The three possible visitation states of HLO instructions. Each
+  // instruction's state only flows one way: kNotVisited -> kVisiting ->
+  // kVisited.
+ enum VisitState {
+ kNotVisited,
+ kVisiting,
+ kVisited,
+ };
+
+ // Sets the visitation state of the given instruction as kVisiting.
+ //
+ // Precondition: current state must be kNotVisited.
+ void SetVisiting(const HloInstruction& instruction);
+
+ // Sets the visitation state of the given instruction as kVisited.
+ //
+ // Precondition: current state must be either kNotVisited or kVisiting.
+ void SetVisited(const HloInstruction& instruction);
+
+ // Returns whether the state of the given instruction is kVisiting.
+ bool IsVisiting(const HloInstruction& instruction);
+
+ // Returns whether the state of the given instruction is kVisited.
+ bool DidVisit(const HloInstruction& instruction);
+
+ // Returns whether the state of the given instruction is kNotVisited.
+ bool NotVisited(const HloInstruction& instruction);
+
+ // This method should be overridden by subclasses that wish to run some
+ // operation on an op before its Handle* visitor method is called.
+ //
+ // For any HLO op, the order of calls is:
+ //
+ // Preprocess(op);
+ // Handle/OpType/(op);
+ // Postprocess(op);
+ //
+ // Overriding methods should call DfsHloVisitor::Preprocess before doing their
+ // own preprocessing.
+ virtual Status Preprocess(HloInstruction* hlo);
+
+ // This method should be overridden by subclasses that wish to run some
+ // operation on an op after its Handle* visitor method is called. See
+ // Preprocess for more details.
+ //
+ // Overriding methods should call DfsHloVisitor::Postprocess after doing their
+ // own postprocessing.
+ virtual Status Postprocess(HloInstruction* visited);
+
+ private:
+  // Tracks the visitation state of each instruction. Any instruction that is
+  // not present in the map is considered to be in VisitState::kNotVisited.
+ tensorflow::gtl::FlatMap<const HloInstruction*, VisitState> visit_state_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_
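
The Preprocess / Handle* / Postprocess ordering documented above can be sketched with a hypothetical subclass. Illustrative only: it leans on DfsHloVisitorWithDefault (added in the next file) so that only DefaultAction must be overridden, and the class name TracingVisitor is made up.

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical visitor that logs around every Handle* call, following the
// Preprocess -> Handle* -> Postprocess order documented above.
class TracingVisitor : public xla::DfsHloVisitorWithDefault {
 public:
  xla::Status DefaultAction(xla::HloInstruction* hlo) override {
    return xla::Status::OK();
  }
  xla::Status Preprocess(xla::HloInstruction* hlo) override {
    // Call the base class before doing our own preprocessing, as documented.
    xla::Status status = xla::DfsHloVisitor::Preprocess(hlo);
    if (!status.ok()) {
      return status;
    }
    VLOG(1) << "entering HLO " << hlo;
    return xla::Status::OK();
  }
  xla::Status Postprocess(xla::HloInstruction* hlo) override {
    VLOG(1) << "leaving HLO " << hlo;
    // ... and the base class after doing our own postprocessing.
    return xla::DfsHloVisitor::Postprocess(hlo);
  }
};
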
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
new file mode 100644
index 0000000000..4808c7a041
--- /dev/null
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -0,0 +1,226 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class HloComputation;
+class HloInstruction;
+
+// DfsHloVisitor with default action based on the HloInstruction being visited.
+class DfsHloVisitorWithDefault : public DfsHloVisitor {
+ public:
+ DfsHloVisitorWithDefault() {}
+ ~DfsHloVisitorWithDefault() override {}
+
+ // Default action performed on HloInstruction.
+ virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0;
+
+ Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* operand) override {
+ return DefaultAction(hlo);
+ }
+ Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ return DefaultAction(hlo);
+ }
+ Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/,
+ HloInstruction* /*arg*/,
+ HloInstruction* /*max*/) override {
+ return DefaultAction(clamp);
+ }
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
+ return DefaultAction(concatenate);
+ }
+ Status HandleConvert(HloInstruction* convert,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(convert);
+ }
+ Status HandleCopy(HloInstruction* copy,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(copy);
+ }
+ Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/,
+ HloInstruction* /*on_true*/,
+ HloInstruction* /*on_false*/) override {
+ return DefaultAction(select);
+ }
+ Status HandleDot(HloInstruction* dot, HloInstruction* /*lhs*/,
+ HloInstruction* /*rhs*/) override {
+ return DefaultAction(dot);
+ }
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* /*lhs*/,
+ HloInstruction* /*rhs*/,
+ const Window& /*window*/) override {
+ return DefaultAction(convolution);
+ }
+ Status HandleCrossReplicaSum(HloInstruction* crs) override {
+ return DefaultAction(crs);
+ }
+ Status HandleCompare(HloInstruction* compare, HloOpcode /*opcode*/,
+ HloInstruction* /*lhs*/,
+ HloInstruction* /*rhs*/) override {
+ return DefaultAction(compare);
+ }
+ Status HandleRng(HloInstruction* random,
+ RandomDistribution /*distribution*/) override {
+ return DefaultAction(random);
+ }
+ Status HandleInfeed(HloInstruction* infeed) override {
+ return DefaultAction(infeed);
+ }
+ Status HandleReverse(HloInstruction* reverse,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(reverse);
+ }
+ Status HandleSort(HloInstruction* sort,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(sort);
+ }
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& /*literal*/) override {
+ return DefaultAction(constant);
+ }
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(get_tuple_element);
+ }
+ Status HandleParameter(HloInstruction* parameter) override {
+ return DefaultAction(parameter);
+ }
+ Status HandleFusion(HloInstruction* fusion) override {
+ return DefaultAction(fusion);
+ }
+ Status HandleCall(HloInstruction* call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/,
+ HloComputation* /*computation*/) override {
+ return DefaultAction(call);
+ }
+ Status HandleCustomCall(
+ HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/,
+ tensorflow::StringPiece /*custom_call_target*/) override {
+ return DefaultAction(custom_call);
+ }
+ Status HandleSlice(HloInstruction* slice,
+ HloInstruction* /*operand*/) override {
+ return DefaultAction(slice);
+ }
+ Status HandleDynamicSlice(
+ HloInstruction* slice,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
+ return DefaultAction(slice);
+ }
+ Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+ HloInstruction* /*operand*/,
+ HloInstruction* /*update*/,
+ HloInstruction* /*start_indices*/) override {
+ return DefaultAction(dynamic_update_slice);
+ }
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/) override {
+ return DefaultAction(tuple);
+ }
+ Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*operands*/,
+ HloComputation* /*function*/,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/)
+ override {
+ return DefaultAction(map);
+ }
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* /*arg*/,
+ HloInstruction* /*init_value*/,
+ tensorflow::gtl::ArraySlice<int64> /*dimensions*/,
+ HloComputation* /*function*/) override {
+ return DefaultAction(reduce);
+ }
+ Status HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* /*operand*/,
+ const Window& /*window*/,
+ HloComputation* /*function*/) override {
+ return DefaultAction(reduce_window);
+ }
+ Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override {
+ return DefaultAction(select_and_scatter);
+ }
+ Status HandleBitcast(HloInstruction* bitcast) override {
+ return DefaultAction(bitcast);
+ }
+ Status HandleBroadcast(HloInstruction* broadcast) override {
+ return DefaultAction(broadcast);
+ }
+ Status HandlePad(HloInstruction* pad) override { return DefaultAction(pad); }
+ Status HandleReshape(HloInstruction* reshape) override {
+ return DefaultAction(reshape);
+ }
+ Status HandleTranspose(HloInstruction* transpose) override {
+ return DefaultAction(transpose);
+ }
+ Status HandleWhile(HloInstruction* xla_while, HloInstruction* /*init*/,
+ HloComputation* /*condition*/,
+ HloComputation* /*body*/) override {
+ return DefaultAction(xla_while);
+ }
+ Status HandleSend(HloInstruction* send) override {
+ return DefaultAction(send);
+ }
+ Status HandleRecv(HloInstruction* recv) override {
+ return DefaultAction(recv);
+ }
+
+ // Invoked to inform the visitor that the traversal has completed, and that
+ // the root was "root".
+ Status FinishVisit(HloInstruction* /*root*/) override { return Status::OK(); }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefault);
+};
+
+// Helper class for Accept(VisitorFunction), which visits instructions in DFS
+// order, calling the given function at each instruction.
+class FunctionVisitor : public DfsHloVisitorWithDefault {
+ public:
+ using VisitorFunction = std::function<Status(HloInstruction*)>;
+ explicit FunctionVisitor(VisitorFunction visitor_func)
+ : visitor_func_(std::move(visitor_func)) {}
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ return visitor_func_(hlo_instruction);
+ }
+
+ private:
+ VisitorFunction visitor_func_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_
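
A second sketch, assuming HloInstruction::Accept(DfsHloVisitor*) as referenced in the FunctionVisitor comment above: a hypothetical helper that counts the instructions reachable from a root using FunctionVisitor. Illustrative only; the helper name is made up.

#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

// Hypothetical helper: count instructions reachable from `root` by visiting
// them in postorder with a FunctionVisitor.
xla::Status CountReachableInstructions(xla::HloInstruction* root,
                                       tensorflow::int64* count) {
  *count = 0;
  xla::FunctionVisitor visitor([count](xla::HloInstruction* hlo) {
    ++*count;  // DefaultAction runs exactly once per instruction.
    return xla::Status::OK();
  });
  // Accept() drives the postorder DFS and ends by calling FinishVisit(root).
  return root->Accept(&visitor);
}
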
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
new file mode 100644
index 0000000000..1a87a0043a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -0,0 +1,934 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Intrinsics.h"
+#include "external/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+using llvm_ir::IrArray;
+using llvm_ir::SetToFirstInsertPoint;
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const {
+ if (op->opcode() == HloOpcode::kCopy) {
+ return operand_value;
+ } else {
+ return operand_value->getType()->isIntegerTy()
+ ? EmitIntegerUnaryOp(op, operand_value)
+ : EmitFloatUnaryOp(op, operand_value);
+ }
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const {
+ switch (op->opcode()) {
+ case HloOpcode::kConvert: {
+ PrimitiveType from_type = op->operand(0)->shape().element_type();
+ PrimitiveType to_type = op->shape().element_type();
+ CHECK(primitive_util::IsIntegralType(from_type));
+ if (from_type == to_type) {
+ return operand_value;
+ }
+ if (primitive_util::IsIntegralType(to_type)) {
+ return ir_builder_->CreateIntCast(
+ operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_),
+ primitive_util::IsSignedIntegralType(to_type));
+ }
+ if (primitive_util::IsFloatingPointType(to_type)) {
+ if (primitive_util::IsSignedIntegralType(from_type)) {
+ return ir_builder_->CreateSIToFP(
+ operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
+ }
+ if (primitive_util::IsUnsignedIntegralType(from_type)) {
+ return ir_builder_->CreateUIToFP(
+ operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
+ }
+ }
+ return Unimplemented("conversion from primitive type %s to %s",
+ PrimitiveType_Name(from_type).c_str(),
+ PrimitiveType_Name(to_type).c_str());
+ }
+ case HloOpcode::kAbs: {
+ bool is_signed =
+ primitive_util::IsSignedIntegralType(op->shape().element_type());
+ if (is_signed) {
+ auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
+ ir_builder_);
+ auto zero = llvm::ConstantInt::get(type, 0);
+ auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
+ return ir_builder_->CreateSelect(cmp, operand_value,
+ ir_builder_->CreateNeg(operand_value));
+ } else {
+ return operand_value;
+ }
+ }
+ case HloOpcode::kSign: {
+ bool is_signed =
+ primitive_util::IsSignedIntegralType(op->shape().element_type());
+ auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(),
+ ir_builder_);
+ auto zero = llvm::ConstantInt::get(type, 0);
+ auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
+ if (is_signed) {
+ auto ashr = ir_builder_->CreateAShr(operand_value,
+ type->getIntegerBitWidth() - 1);
+ return ir_builder_->CreateSelect(cmp, zero,
+ ir_builder_->CreateOr(ashr, 1));
+ } else {
+ return ir_builder_->CreateSelect(cmp, zero,
+ llvm::ConstantInt::get(type, 1));
+ }
+ }
+ case HloOpcode::kNegate:
+ return ir_builder_->CreateNeg(operand_value);
+ case HloOpcode::kLogicalNot:
+ // It is not sufficient to just call CreateNot() here because a PRED is
+ // represented as an i8 and the truth value is stored only in the bottom
+ // bit.
+ return ir_builder_->CreateZExt(
+ ir_builder_->CreateNot(ir_builder_->CreateTrunc(
+ operand_value, ir_builder_->getInt1Ty())),
+ llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_));
+ default:
+ return Unimplemented("unary integer op '%s'",
+ HloOpcodeString(op->opcode()).c_str());
+ }
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const {
+ switch (op->opcode()) {
+ case HloOpcode::kConvert: {
+ PrimitiveType from_type = op->operand(0)->shape().element_type();
+ PrimitiveType to_type = op->shape().element_type();
+ CHECK(primitive_util::IsFloatingPointType(from_type));
+ if (from_type == to_type) {
+ return operand_value;
+ }
+ if (primitive_util::IsFloatingPointType(to_type)) {
+ return ir_builder_->CreateFPCast(
+ operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
+ }
+ if (primitive_util::IsSignedIntegralType(to_type)) {
+ return ir_builder_->CreateFPToSI(
+ operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
+ }
+ if (primitive_util::IsUnsignedIntegralType(to_type)) {
+ return ir_builder_->CreateFPToUI(
+ operand_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_));
+ }
+ return Unimplemented("unhandled conversion operation: %s => %s",
+ PrimitiveType_Name(from_type).c_str(),
+ PrimitiveType_Name(to_type).c_str());
+ }
+ case HloOpcode::kExp:
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value},
+ {operand_value->getType()},
+ ir_builder_);
+ case HloOpcode::kLog:
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value},
+ {operand_value->getType()},
+ ir_builder_);
+ case HloOpcode::kFloor:
+ return llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
+ ir_builder_);
+ case HloOpcode::kCeil:
+ return llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
+ ir_builder_);
+ case HloOpcode::kAbs:
+ return llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
+ ir_builder_);
+ case HloOpcode::kSign: {
+ // TODO(b/32151903): Ensure consistent sign behavior for -0.0
+ auto type = operand_value->getType();
+ auto zero = llvm::ConstantFP::get(type, 0.0);
+ auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
+ auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
+ return ir_builder_->CreateSelect(
+ oeq, zero,
+ ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
+ llvm::ConstantFP::get(type, 1.0)));
+ }
+ case HloOpcode::kNegate:
+ return ir_builder_->CreateFNeg(operand_value);
+ default:
+ return Unimplemented("unary floating-point op '%s'",
+ HloOpcodeString(op->opcode()).c_str());
+ }
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ return lhs_value->getType()->isIntegerTy()
+ ? EmitIntegerBinaryOp(op, lhs_value, rhs_value,
+ primitive_util::IsSignedIntegralType(
+ op->operand(0)->shape().element_type()))
+ : EmitFloatBinaryOp(op, lhs_value, rhs_value);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ switch (op->opcode()) {
+ case HloOpcode::kAdd:
+ return ir_builder_->CreateFAdd(lhs_value, rhs_value);
+ case HloOpcode::kSubtract:
+ return ir_builder_->CreateFSub(lhs_value, rhs_value);
+ case HloOpcode::kMultiply:
+ return ir_builder_->CreateFMul(lhs_value, rhs_value);
+ case HloOpcode::kDivide:
+ return ir_builder_->CreateFDiv(lhs_value, rhs_value);
+ case HloOpcode::kRemainder:
+ return ir_builder_->CreateFRem(lhs_value, rhs_value);
+
+    // The 'O' prefix on the LLVM ops means "ordered" compare, i.e. comparisons
+    // with NaN always return false.
+ case HloOpcode::kEq:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kNe:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_ONE, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kLt:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kGt:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kLe:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kGe:
+ return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
+ rhs_value, ir_builder_);
+
+ case HloOpcode::kMaximum:
+ return EmitFloatMax(lhs_value, rhs_value);
+ case HloOpcode::kMinimum:
+ return EmitFloatMin(lhs_value, rhs_value);
+ case HloOpcode::kPower:
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow,
+ {lhs_value, rhs_value},
+ {lhs_value->getType()}, ir_builder_);
+
+ default:
+ return Unimplemented("binary floating point op '%s'",
+ HloOpcodeString(op->opcode()).c_str());
+ }
+}
+
+llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::maxnum,
+ {lhs_value, rhs_value},
+ {lhs_value->getType()}, ir_builder_);
+}
+
+llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::minnum,
+ {lhs_value, rhs_value},
+ {lhs_value->getType()}, ir_builder_);
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
+ llvm::Value* x) const {
+ if (prim_type != F32) {
+ return Unimplemented("inverse erf");
+ }
+ auto getFloat = [&](const float f) {
+ return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
+ };
+ auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
+ llvm::Value* w) {
+ llvm::Value* p = getFloat(coefficients.front());
+ coefficients.pop_front();
+ for (float coefficient : coefficients) {
+ p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
+ getFloat(coefficient));
+ }
+ return p;
+ };
+
+ // Approximation for inverse error function from
+ // Giles, M., "Approximating the erfinv function".
+ // The approximation has the form:
+ // w = log((1-x)*(1+x))
+ // if ( w < 5 ) {
+ // w = w - 2.5
+ // p = sum_{i=1}^n lq[i]*w^i
+ // } else {
+ // w = sqrt(w) - 3
+ // p = sum_{i=1}^n gq[i]*w^i
+ // }
+ // return p*x
+ llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});
+
+ llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
+ logf_fn,
+ {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
+ ir_builder_->CreateFAdd(getFloat(1.0f), x))}));
+
+ llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ ir_builder_->getFloatTy(), "p.addr", ir_builder_);
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
+ "w_less_than_five", ir_builder_);
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ {
+ llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
+ tensorflow::gtl::ArraySlice<float> lq{
+ 2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
+ -4.39150654e-06f, 0.00021858087f, -0.00125372503f,
+ -0.00417768164f, 0.246640727f, 1.50140941f};
+ llvm::Value* p = multiply_add(lq, lw);
+ ir_builder_->CreateStore(p, p_addr);
+ }
+
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ {
+ llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
+ module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});
+
+ llvm::Value* gw = ir_builder_->CreateFSub(
+ ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
+ tensorflow::gtl::ArraySlice<float> gq{
+ -0.000200214257f, 0.000100950558f, 0.00134934322f,
+ -0.00367342844f, 0.00573950773f, -0.0076224613f,
+ 0.00943887047f, 1.00167406f, 2.83297682f};
+ llvm::Value* p = multiply_add(gq, gw);
+ ir_builder_->CreateStore(p, p_addr);
+ }
+
+ SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ llvm::Value* p = ir_builder_->CreateLoad(p_addr);
+ return ir_builder_->CreateFMul(p, x);
+}
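
For reference, the polynomial form in the comment above can be written as plain scalar code. This is an illustrative sketch only (the function name ErfInvApprox is made up); the coefficients mirror the lq/gq arrays used in EmitErfInv.

#include <cmath>
#include <initializer_list>

// Scalar reference for the Giles erfinv approximation emitted above.
float ErfInvApprox(float x) {
  float w = -std::log((1.0f - x) * (1.0f + x));
  float p;
  if (w < 5.0f) {
    w -= 2.5f;
    p = 2.81022636e-08f;
    for (float c : {3.43273939e-07f, -3.5233877e-06f, -4.39150654e-06f,
                    0.00021858087f, -0.00125372503f, -0.00417768164f,
                    0.246640727f, 1.50140941f}) {
      p = p * w + c;  // Horner evaluation, as in the multiply_add lambda.
    }
  } else {
    w = std::sqrt(w) - 3.0f;
    p = -0.000200214257f;
    for (float c : {0.000100950558f, 0.00134934322f, -0.00367342844f,
                    0.00573950773f, -0.0076224613f, 0.00943887047f,
                    1.00167406f, 2.83297682f}) {
      p = p * w + c;
    }
  }
  return p * x;
}
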
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ // Compute erfcinv(value) by calculating erfinv(1.0 - value).
+ auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_);
+ auto one = llvm::ConstantFP::get(type, 1.0);
+ return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
+}
+
+StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
+ bool is_signed) const {
+ switch (op->opcode()) {
+ // TODO(jingyue): add the "nsw" attribute for signed types.
+ case HloOpcode::kAdd:
+ return ir_builder_->CreateAdd(lhs_value, rhs_value);
+ case HloOpcode::kSubtract:
+ return ir_builder_->CreateSub(lhs_value, rhs_value);
+ case HloOpcode::kMultiply:
+ return ir_builder_->CreateMul(lhs_value, rhs_value);
+ case HloOpcode::kDivide:
+ return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
+ : ir_builder_->CreateUDiv(lhs_value, rhs_value);
+ case HloOpcode::kRemainder:
+ return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
+ : ir_builder_->CreateURem(lhs_value, rhs_value);
+ case HloOpcode::kEq:
+ return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kNe:
+ return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
+ rhs_value, ir_builder_);
+ case HloOpcode::kLt:
+ return llvm_ir::EmitComparison(
+ is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
+ lhs_value, rhs_value, ir_builder_);
+ case HloOpcode::kGt:
+ return llvm_ir::EmitComparison(
+ is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
+ lhs_value, rhs_value, ir_builder_);
+ case HloOpcode::kLe:
+ return llvm_ir::EmitComparison(
+ is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
+ lhs_value, rhs_value, ir_builder_);
+ case HloOpcode::kGe:
+ return llvm_ir::EmitComparison(
+ is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
+ lhs_value, rhs_value, ir_builder_);
+ case HloOpcode::kMinimum:
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
+ case HloOpcode::kMaximum:
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateICmp(
+ is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
+ case HloOpcode::kLogicalAnd:
+ return ir_builder_->CreateAnd(lhs_value, rhs_value);
+ case HloOpcode::kLogicalOr:
+ return ir_builder_->CreateOr(lhs_value, rhs_value);
+ default:
+ return Unimplemented("binary integer op '%s'",
+ HloOpcodeString(op->opcode()).c_str());
+ }
+}
+
+llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
+ const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
+ int64 operand_no) const {
+ CHECK(hlo.IsElementwise()) << "HLO " << hlo.ToString()
+ << " is not elementwise.";
+
+ const Shape& operand_shape = hlo.operand(operand_no)->shape();
+ // If the operand is scalar, the source index is always {}.
+ if (ShapeUtil::IsScalar(operand_shape)) {
+ return llvm_ir::IrArray::Index();
+ }
+
+ // If no implicit broadcast is needed for this operand, returns the target
+ // index as the source index.
+ if (ShapeUtil::Compatible(operand_shape, hlo.shape())) {
+ return target_index;
+ }
+
+ // If implicit broadcast is needed, the source dimensions that are broadcast
+ // have index 0.
+ CHECK_EQ(ShapeUtil::Rank(operand_shape), ShapeUtil::Rank(hlo.shape()));
+ llvm_ir::IrArray::Index source_index;
+ for (int64 i = 0; i < ShapeUtil::Rank(hlo.shape()); ++i) {
+ if (hlo.shape().dimensions(i) == operand_shape.dimensions(i)) {
+ source_index.push_back(target_index[i]);
+ } else {
+ CHECK_EQ(1, operand_shape.dimensions(i));
+ source_index.push_back(ir_builder_->getInt64(0));
+ }
+ }
+ return source_index;
+}
+
+llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
+ const {
+ PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type();
+ llvm::Type* param_ir_type =
+ llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_);
+
+ // Same values as PCG library
+ // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
+ llvm::Value* multiplier = ir_builder_->getInt(
+ llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
+ llvm::Value* increment = ir_builder_->getInt(
+ llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
+
+ auto random_value = [hlo]() {
+ CHECK(hlo->parent() != nullptr && hlo->parent()->parent() != nullptr);
+ const HloModule* module = hlo->parent()->parent();
+ return module->RandomNew64();
+ };
+
+  // Seed each RNG emitter with a new 64-bit seed from the HloModule. If the
+  // compilation order is deterministic (i.e., the RandomNew64 invocation order
+  // is deterministic), then the sequence of RNG values is deterministic for a
+  // given seed and hence tests will be deterministic.
+  // If the user provides a global seed instruction, only 64 bits of the
+  // 128-bit state are seeded from the host's random number generator; the
+  // other 64 bits come from the user-specified global seed instruction.
+  // Create GlobalVariables to maintain state between invocations. There is a
+  // bug in NVPTX with GlobalVariable and 128-bit values, so we use two 64-bit
+  // values instead.
+ llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
+ /*M=*/*module_,
+ /*Ty=*/ir_builder_->getInt64Ty(),
+ /*isConstant=*/false,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/ir_builder_->getInt64(random_value()),
+ /*Name=*/"state_ptr0");
+ uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
+ : random_value();
+ llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
+ /*M=*/*module_,
+ /*Ty=*/ir_builder_->getInt64Ty(),
+ /*isConstant=*/false,
+ /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
+ /*Initializer=*/ir_builder_->getInt64(graph_seed),
+ /*Name=*/"state_ptr1");
+
+  // We want each thread to use its own stream, so we modify the increment per
+  // thread. The increment must remain odd, so we shift the thread id left by 1
+  // before adding it to the increment.
+ increment = ir_builder_->CreateAdd(increment,
+ ir_builder_->CreateShl(EmitThreadId(), 1));
+
+ // PCG-XSL-RR algorithm
+ // http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
+ // state = multiplier * state + increment
+  //   return uint64_t(state ^ (state >> 64)) >>> (state >> 122)
+  // where ">>>" is a bitwise rotate-right
+ auto get_next_i64 = [=]() {
+ llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
+ ir_builder_->CreateLoad(state_ptr0, "state0"),
+ ir_builder_->getInt128Ty());
+ llvm::Value* state1 = ir_builder_->CreateShl(
+ ir_builder_->CreateZExtOrTrunc(
+ ir_builder_->CreateLoad(state_ptr1, "state1"),
+ ir_builder_->getInt128Ty()),
+ 64);
+ llvm::Value* state = ir_builder_->CreateOr(state0, state1);
+ llvm::Value* updated = ir_builder_->CreateAdd(
+ ir_builder_->CreateMul(state, multiplier), increment);
+ ir_builder_->CreateStore(
+ ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
+ state_ptr0);
+ ir_builder_->CreateStore(
+ ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
+ ir_builder_->getInt64Ty()),
+ state_ptr1);
+
+ return llvm_ir::CreateRor(
+ ir_builder_->CreateTrunc(
+ ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
+ ir_builder_->getInt64Ty()),
+ ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
+ ir_builder_->getInt64Ty()),
+ ir_builder_);
+ };
+
+ auto get_next_uniform_float = [=]() {
+ return ir_builder_->CreateFDiv(
+ ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
+ llvm::ConstantFP::get(param_ir_type, 0x1p64));
+ };
+
+ return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ switch (hlo->random_distribution()) {
+ case RNG_UNIFORM: {
+ TF_ASSIGN_OR_RETURN(llvm::Value * p,
+ operand_to_generator.at(hlo->operand(0))(index));
+ TF_ASSIGN_OR_RETURN(llvm::Value * q,
+ operand_to_generator.at(hlo->operand(1))(index));
+ if (primitive_util::IsFloatingPointType(param_prim_type)) {
+ return ir_builder_->CreateFAdd(
+ ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
+ get_next_uniform_float()),
+ p);
+ } else {
+ auto r = ir_builder_->CreateSub(q, p);
+ auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)},
+ {param_ir_type}, ir_builder_);
+ auto in_block = ir_builder_->GetInsertBlock();
+ auto body_block = in_block->splitBasicBlock(
+ ir_builder_->GetInsertPoint(), "rng_body");
+ SetToFirstInsertPoint(body_block, ir_builder_);
+ auto out_block = body_block->splitBasicBlock(
+ ir_builder_->GetInsertPoint(), "rng_out");
+ SetToFirstInsertPoint(body_block, ir_builder_);
+ auto random = ir_builder_->CreateAnd(
+ ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
+ ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
+ leading_zeros));
+ llvm::ReplaceInstWithInst(
+ body_block->getTerminator(),
+ llvm::BranchInst::Create(out_block, body_block,
+ ir_builder_->CreateICmpULE(random, r)));
+ SetToFirstInsertPoint(out_block, ir_builder_);
+ return ir_builder_->CreateAdd(
+ p, ir_builder_->CreateSelect(
+ ir_builder_->CreateICmpEQ(p, q),
+ llvm::ConstantInt::get(param_ir_type, 0), random));
+ }
+ }
+ case RNG_NORMAL: {
+ TF_ASSIGN_OR_RETURN(llvm::Value * m,
+ operand_to_generator.at(hlo->operand(0))(index));
+ TF_ASSIGN_OR_RETURN(llvm::Value * s,
+ operand_to_generator.at(hlo->operand(1))(index));
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * r,
+ EmitErfcInv(param_prim_type,
+ ir_builder_->CreateFMul(
+ llvm::ConstantFP::get(param_ir_type, 2.0),
+ get_next_uniform_float())));
+ return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
+ }
+ case RNG_BERNOULLI: {
+ TF_ASSIGN_OR_RETURN(llvm::Value * p,
+ operand_to_generator.at(hlo->operand(0))(index));
+ return ir_builder_->CreateZExt(
+ ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p),
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
+ ir_builder_));
+ }
+ default:
+ return InvalidArgument(
+ "unhandled distribution %s",
+ RandomDistribution_Name(hlo->random_distribution()).c_str());
+ }
+ };
+}
+
+llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
+ const HloInstruction* hlo,
+ const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator)
+ const {
+  // TODO(mfdyck): Make capture lists explicit, lest someone forget to capture
+  // `operand_to_generator` by reference and its many copies fill memory and
+  // cause much woe and process death.
+ switch (hlo->opcode()) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kCeil:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCopy:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kLog:
+ case HloOpcode::kNegate:
+ case HloOpcode::kSign:
+ case HloOpcode::kTanh:
+ case HloOpcode::kLogicalNot:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ return EmitUnaryOp(hlo, operand_value);
+ };
+ case HloOpcode::kAdd:
+ case HloOpcode::kDivide:
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalOr:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* lhs = hlo->operand(0);
+ const HloInstruction* rhs = hlo->operand(1);
+ TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value,
+ operand_to_generator.at(lhs)(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value,
+ operand_to_generator.at(rhs)(
+ ElementwiseSourceIndex(index, *hlo, 1)));
+ return EmitBinaryOp(hlo, lhs_value, rhs_value);
+ };
+ case HloOpcode::kSelect:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ TF_ASSIGN_OR_RETURN(llvm::Value * pred_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * on_true_value,
+ operand_to_generator.at(hlo->operand(1))(
+ ElementwiseSourceIndex(index, *hlo, 1)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
+ operand_to_generator.at(hlo->operand(2))(
+ ElementwiseSourceIndex(index, *hlo, 2)));
+ return ir_builder_->CreateSelect(
+ ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
+ on_true_value, on_false_value);
+ };
+ case HloOpcode::kClamp:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ TF_ASSIGN_OR_RETURN(llvm::Value * min_value,
+ operand_to_generator.at(hlo->operand(0))(
+ ElementwiseSourceIndex(index, *hlo, 0)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * arg_value,
+ operand_to_generator.at(hlo->operand(1))(
+ ElementwiseSourceIndex(index, *hlo, 1)));
+ TF_ASSIGN_OR_RETURN(llvm::Value * max_value,
+ operand_to_generator.at(hlo->operand(2))(
+ ElementwiseSourceIndex(index, *hlo, 2)));
+ return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value));
+ };
+ case HloOpcode::kConcatenate:
+ return [=, &operand_to_generator](
+ const IrArray::Index target_index) -> StatusOr<llvm::Value*> {
+ const int64 concat_dim = hlo->dimensions(0);
+ auto source_index = target_index;
+
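+        // The code emitted below chains one test per operand: if the running
+        // concat-dimension index falls inside operand i, branch to a block
+        // that reads the element from that operand and jumps to the merge
+        // block; otherwise subtract operand i's extent along the concat
+        // dimension and fall through to the test for operand i+1. The PHI
+        // node collects the per-operand values in the merge block.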
+ llvm::PHINode* output = ir_builder_->CreatePHI(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
+ ir_builder_),
+ hlo->operands().size());
+ llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
+ auto prior_insert_point = ir_builder_->GetInsertPoint();
+ llvm::BasicBlock* exit_block =
+ init_block->splitBasicBlock(output, "concat_merge");
+
+ ir_builder_->SetInsertPoint(init_block);
+ init_block->getTerminator()->eraseFromParent();
+
+ for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
+ ++operand_idx) {
+ const HloInstruction* operand = hlo->operand(operand_idx);
+ auto true_block = llvm_ir::CreateBasicBlock(
+ exit_block, tensorflow::strings::StrCat(
+ "concat_index_from_operand", operand_idx),
+ ir_builder_);
+ auto false_block = llvm_ir::CreateBasicBlock(
+ exit_block, tensorflow::strings::StrCat(
+ "concat_index_not_from_operand", operand_idx),
+ ir_builder_);
+ auto concat_dim_size =
+ llvm::ConstantInt::get(source_index[concat_dim]->getType(),
+ operand->shape().dimensions(concat_dim));
+ ir_builder_->CreateCondBr(
+ ir_builder_->CreateICmpULT(source_index[concat_dim],
+ concat_dim_size),
+ true_block, false_block);
+
+ // Create the terminator of the true block before calling operand
+ // generators, because they require non-degenerate basic blocks.
+ ir_builder_->SetInsertPoint(
+ llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
+ TF_ASSIGN_OR_RETURN(llvm::Value * value,
+ operand_to_generator.at(operand)(source_index));
+ output->addIncoming(value, ir_builder_->GetInsertBlock());
+
+ // Subtract the size of the concat dimension of the current operand
+ // from the source index.
+ ir_builder_->SetInsertPoint(false_block);
+ source_index[concat_dim] =
+ ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
+ }
+
+ ir_builder_->CreateUnreachable();
+ ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
+ return output;
+ };
+ case HloOpcode::kReverse:
+ return [=, &operand_to_generator](
+ const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* operand = hlo->operand(0);
+ auto source_index = target_index;
+ for (int64 dim : hlo->dimensions()) {
+ source_index[dim] = ir_builder_->CreateSub(
+ llvm::ConstantInt::get(target_index[dim]->getType(),
+ hlo->shape().dimensions(dim) - 1),
+ target_index[dim]);
+ }
+ return operand_to_generator.at(operand)(source_index);
+ };
+ case HloOpcode::kBroadcast:
+ return [=, &operand_to_generator](
+ const IrArray::Index& target_index) -> StatusOr<llvm::Value*> {
+ // The `dimensions` member of the broadcast instruction maps from
+ // input dimensions to output dimensions.
+ const HloInstruction* operand = hlo->operand(0);
+ int64 rank = ShapeUtil::Rank(operand->shape());
+ IrArray::Index source_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ source_index[i] = target_index[hlo->dimensions(i)];
+ }
+ return operand_to_generator.at(operand)(source_index);
+ };
+ case HloOpcode::kSlice:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ IrArray::Index sliced_index(index.size());
+ for (int i = 0; i < index.size(); ++i) {
+ sliced_index[i] = ir_builder_->CreateAdd(
+ index[i], llvm::ConstantInt::get(index[i]->getType(),
+ hlo->slice_starts(i)));
+ }
+ return operand_to_generator.at(hlo->operand(0))(sliced_index);
+ };
+ case HloOpcode::kDynamicSlice:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ // Emit IR to read dynamic start indices from hlo->operand(1).
+ const HloInstruction* input_hlo = hlo->operand(0);
+ const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+ llvm_ir::IrArray::Index slice_start_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * start_index_value,
+ operand_to_generator.at(hlo->operand(1))(dim_index));
+ slice_start_index[i] = start_index_value;
+ }
+
+ llvm_ir::IrArray::Index input_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ // Emit IR which computes:
+ // input_index = (start_index + offset_index) % dim_size
+ // Security note: this is the code that keeps the indices in-bounds.
+ llvm::Value* dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), input_hlo->shape().dimensions(i));
+ llvm::Value* start_index = ir_builder_->CreateZExtOrBitCast(
+ slice_start_index[i], index[i]->getType());
+ input_index[i] = ir_builder_->CreateURem(
+ ir_builder_->CreateAdd(start_index, index[i]), dim_size);
+ }
+ return operand_to_generator.at(input_hlo)(input_index);
+ };
+ case HloOpcode::kDynamicUpdateSlice:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* input_hlo = hlo->operand(0);
+ const HloInstruction* update_hlo = hlo->operand(1);
+ const HloInstruction* start_hlo = hlo->operand(2);
+ // Calculate slice start/end indices.
+ const int64 rank = ShapeUtil::Rank(input_hlo->shape());
+ llvm_ir::IrArray::Index slice_start_index(rank);
+ llvm_ir::IrArray::Index slice_limit_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ // Emit IR to read dynamic start indices from 'start_hlo'.
+ llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i));
+ TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value,
+ operand_to_generator.at(start_hlo)(dim_index));
+ slice_start_index[i] = ir_builder_->CreateZExtOrBitCast(
+ start_index_value, index[i]->getType());
+ // Emit IR to compute: slice_limit_index = start_index + update_dim
+ // NOTE: Although 'start_indices' is dynamic and could be
+ // out-of-range, we do not compute 'slice_limit_index' mod input dim
+ // size here, because subsequent array index calculations will be
+ // computed mod input dim size for safety.
+ llvm::Value* update_dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), update_hlo->shape().dimensions(i));
+ slice_limit_index[i] =
+ ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
+ }
+
+ // Check if 'index' intersects start/end indices.
+ llvm::Value* slice_intersection =
+ llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
+
+ for (int64 i = 0; i < rank; ++i) {
+ // Check that index[i] >= slice_start_index[i].
+ slice_intersection = ir_builder_->CreateAnd(
+ slice_intersection,
+ ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
+ "slice_intersection");
+
+ // Check that index[i] < slice_limit_index[i].
+ slice_intersection = ir_builder_->CreateAnd(
+ slice_intersection,
+ ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
+ "slice_intersection");
+ }
+
+ // Emit:
+ // if (slice_intersection) -> return data from 'update'.
+        // else -> return data from 'input' at 'index'.
+ llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
+ ir_builder_),
+ "ret_value_addr", ir_builder_);
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ slice_intersection, "slice_intersection", ir_builder_);
+
+ // Handle true BB.
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ // Compute update index for intersection case.
+ llvm_ir::IrArray::Index update_index(rank);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* update_dim_size = llvm::ConstantInt::get(
+ index[i]->getType(), update_hlo->shape().dimensions(i));
+            // NOTE: Subtraction will be non-negative due to the bounds checks
+            // above.
+ update_index[i] = ir_builder_->CreateURem(
+ ir_builder_->CreateSub(index[i], slice_start_index[i]),
+ update_dim_size);
+ }
+ TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
+ operand_to_generator.at(update_hlo)(update_index));
+ ir_builder_->CreateStore(true_value, ret_value_addr);
+
+ // Handle false BB.
+ SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
+ operand_to_generator.at(input_hlo)(index));
+ ir_builder_->CreateStore(false_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ return ir_builder_->CreateLoad(ret_value_addr);
+ };
+ case HloOpcode::kReshape:
+ CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
+ ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
+ return [=, &operand_to_generator](const IrArray::Index& index) {
+ const HloInstruction* operand = hlo->operand(0);
+ return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
+ hlo->shape(), operand->shape(), ir_builder_));
+ };
+ case HloOpcode::kTranspose:
+ return [=, &operand_to_generator](const IrArray::Index& target_index) {
+ return operand_to_generator.at(hlo->operand(0))(
+ target_index.SourceIndexOfTranspose(
+ hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
+ ir_builder_));
+ };
+ case HloOpcode::kRng:
+ return MakeRngElementGenerator(hlo, operand_to_generator);
+ default:
+ return [=, &operand_to_generator](const IrArray::Index& index) {
+ return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str());
+ };
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
new file mode 100644
index 0000000000..2576d3823e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -0,0 +1,118 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
+
+#include <unordered_map>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+class ElementalIrEmitter {
+ public:
+ using HloToElementGeneratorMap =
+ std::unordered_map<const HloInstruction*, llvm_ir::ElementGenerator>;
+
+ ElementalIrEmitter(const HloModuleConfig& hlo_module_config,
+ llvm::Module* module, llvm::IRBuilder<>* ir_builder)
+ : ir_builder_(ir_builder),
+ module_(module),
+ hlo_module_config_(hlo_module_config) {}
+
+ virtual ~ElementalIrEmitter() {}
+
+ virtual StatusOr<llvm::Value*> EmitUnaryOp(const HloInstruction* op,
+ llvm::Value* operand_value) const;
+
+ virtual StatusOr<llvm::Value*> EmitBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const;
+
+ // Returns a function to generate an element of the output of `hlo`, given a
+ // map of functions to generate elements of its operands.
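+  //
+  // Illustrative sketch only (the names below are hypothetical): a backend
+  // emitter could wire this up roughly as
+  //
+  //   ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
+  //   // ... install a generator for each operand of `hlo` ...
+  //   llvm_ir::ElementGenerator gen =
+  //       emitter.MakeElementGenerator(hlo, operand_to_generator);
+  //   TF_ASSIGN_OR_RETURN(llvm::Value * element, gen(index));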
+ virtual llvm_ir::ElementGenerator MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const;
+
+ llvm::IRBuilder<>* ir_builder() const { return ir_builder_; }
+
+ protected:
+ virtual StatusOr<llvm::Value*> EmitIntegerUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const;
+
+ virtual StatusOr<llvm::Value*> EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const;
+
+ virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
+ llvm::Value* lhs_value,
+ llvm::Value* rhs_value,
+ bool is_signed) const;
+
+ virtual StatusOr<llvm::Value*> EmitFloatBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const;
+
+ virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const;
+
+ virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const;
+
+ virtual StatusOr<llvm::Value*> EmitErfInv(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ virtual StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
+ llvm::Value* value) const;
+
+ // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and
+ // the target array index, computes the source array index of its
+ // `operand_no`-th operand.
+ //
+ // Precondition: `hlo` is an elementwise op.
+ llvm_ir::IrArray::Index ElementwiseSourceIndex(
+ const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo,
+ int64 operand_no) const;
+
+  // Returns an identifier of the thread, unique among all threads on the
+  // device.
+ virtual llvm::Value* EmitThreadId() const {
+ return ir_builder_->getIntN(128, 0);
+ }
+
+ llvm::IRBuilder<>* const ir_builder_;
+
+ llvm::Module* module_;
+
+ // The HloModuleConfig which gathers all settings and values which affect the
+ // compiled executable outside of the HLO code itself.
+ const HloModuleConfig& hlo_module_config_;
+
+ private:
+ // Returns a ElementGenerator for a RNG HloInstruction.
+ llvm_ir::ElementGenerator MakeRngElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
new file mode 100644
index 0000000000..5b1a5a16d1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/executable.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace xla {
+
+StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
+Executable::ExecuteOnStreams(
+ tensorflow::gtl::ArraySlice<const ExecutableRunOptions> run_options,
+ tensorflow::gtl::ArraySlice<
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
+ arguments) {
+ TF_RET_CHECK(run_options.size() == arguments.size());
+
+ if (run_options.size() == 1) {
+ TF_ASSIGN_OR_RETURN(auto result,
+ ExecuteOnStream(&run_options[0], arguments[0],
+ /*hlo_execution_profile=*/nullptr));
+ return std::vector<perftools::gputools::DeviceMemoryBase>({result});
+ }
+
+ std::vector<perftools::gputools::DeviceMemoryBase> return_values(
+ run_options.size());
+ for (int64 i = 0; i < run_options.size(); ++i) {
+    // In case of an error we cannot BlockHostUntilDone() on the
+    // already-launched executions: if the executions communicate with each
+    // other, the ones launched first may never complete unless all of them
+    // are running.
+ TF_ASSIGN_OR_RETURN(return_values[i],
+ ExecuteAsyncOnStream(&run_options[i], arguments[i]));
+ }
+ for (const auto& options : run_options) {
+ TF_RET_CHECK(options.stream() != nullptr);
+ options.stream()->BlockHostUntilDone();
+ }
+ return return_values;
+}
+
+Status Executable::DumpSessionModule() {
+ TF_RET_CHECK(dumping());
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ const string& directory_path = flags->xla_dump_executions_to;
+ VersionedComputationHandle versioned_handle = entry_computation_handle();
+ // This filename does not include the version number because the computation
+ // is only ever executed at one version.
+ string filename = tensorflow::strings::Printf(
+ "computation_%lld__%s__execution_%lld", versioned_handle.handle.handle(),
+ session_module_->entry().name().c_str(), ++execution_count_);
+ return Executable::DumpToDirectory(directory_path, filename,
+ *session_module_);
+}
+
+/* static */ Status Executable::DumpToDirectory(
+ const string& directory_path, const string& filename,
+ const SessionModule& session_module) {
+ tensorflow::Env* env = tensorflow::Env::Default();
+ if (!env->IsDirectory(directory_path).ok()) {
+ TF_RETURN_IF_ERROR(env->CreateDir(directory_path));
+ }
+ string file_path = tensorflow::io::JoinPath(directory_path, filename);
+ return tensorflow::WriteBinaryProto(env, file_path, session_module);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
new file mode 100644
index 0000000000..373ab79ab2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+
+// A given platform's compiler will produce an Executable -- this is a uniform
+// interface that is used for launching compiled programs across platforms.
+//
+// TODO(leary) will need to extend this to support multiple streams/devices as
+// we begin to compile single programs to run on multiple devices.
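+//
+// Illustrative sketch only (assumes `executable`, `run_options` and
+// `arguments` have already been prepared by the caller):
+//
+//   TF_ASSIGN_OR_RETURN(
+//       perftools::gputools::DeviceMemoryBase result,
+//       executable->ExecuteOnStream(&run_options, arguments,
+//                                   /*hlo_execution_profile=*/nullptr));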
+class Executable {
+ public:
+ explicit Executable(std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config)
+ : hlo_module_(std::move(hlo_module)),
+ module_config_(std::move(module_config)) {}
+ virtual ~Executable() {}
+
+ // Enqueues the compilation result on the provided stream, passing the given
+ // arguments. This call is blocking and returns after the execution is done.
+ //
+ // If the hlo_execution_profile is provided as non-nullptr, profiling will be
+ // enabled.
+ //
+ // Returns the device memory region that a successful execution would
+ // populate.
+ virtual StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) = 0;
+
+ // Overload of ExecuteOnStream which returns and takes arguments as
+ // ShapedBuffers. Used for LocalService execution.
+ virtual StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) = 0;
+
+  // Overload of ExecuteOnStream which writes the result into a pre-allocated
+  // buffer (result_buffer).
+ virtual Status ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer,
+ HloExecutionProfile* hlo_execution_profile) = 0;
+
+ // Same as ExecuteOnStream(), but this call is non-blocking and returns as
+ // soon as all of the operations are enqueued for launch on the stream.
+ virtual StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) = 0;
+
+ // Same as ExecuteOnStream(), but runs this executable on multiple
+ // streams. arguments[i] contains the arguments to the execution on
+ // run_options[i]->stream() and the returned value is at index i of the
+ // returned vector.
+ virtual StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
+ ExecuteOnStreams(
+ tensorflow::gtl::ArraySlice<const ExecutableRunOptions> run_options,
+ tensorflow::gtl::ArraySlice<
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
+ arguments);
+
+ // Returns the ExecutionProfile from executing on the device. This includes
+ // the number of cycles taken for the computation or the compilation time.
+ ExecutionProfile execution_profile() const {
+ tensorflow::mutex_lock lock(mutex_);
+ return execution_profile_;
+ }
+
+  // Returns whether this executable was compiled with HLO profiling support
+  // enabled. If not, the caller should not expect an hlo_execution_profile
+  // passed to ExecuteOnStream above to be populated during execution.
+ bool hlo_profiling_enabled() const {
+ return module_config_->hlo_profiling_enabled();
+ }
+
+ const HloModule& module() const { return *hlo_module_; }
+
+ const HloModuleConfig& module_config() const { return *module_config_; }
+
+ // Returns the versioned computation handle of the computation computed by
+ // this executable.
+ const VersionedComputationHandle& entry_computation_handle() const {
+ return hlo_module_->entry_computation_handle();
+ }
+
+ // The shape (including layout) that results from this execution. This is the
+ // shape of the DeviceMemoryBase result value in ExecuteOnStream above.
+ const Shape& result_shape() const {
+ return module_config_->entry_computation_layout().result_shape();
+ }
+
+ // Dumping helpers.
+ void set_session_module(std::unique_ptr<xla::SessionModule> session_module) {
+ session_module_ = std::move(session_module);
+ }
+ bool dumping() const { return session_module_ != nullptr; }
+ SessionModule* session_module() const { return session_module_.get(); }
+ Status DumpSessionModule();
+
+ // Dump session_module to directory_path/filename.
+ static Status DumpToDirectory(const string& directory_path,
+ const string& filename,
+ const SessionModule& session_module);
+
+ protected:
+ mutable tensorflow::mutex mutex_;
+
+ // Execution profile data on the device.
+ ExecutionProfile execution_profile_ GUARDED_BY(mutex_);
+
+ // HloModule this was compiled from. BufferAssignment keeps pointers to
+ // HloInstructions owned by the HloModule so we need to keep the HloModule
+ // around.
+ std::unique_ptr<HloModule> hlo_module_;
+
+ // The configuration used to build this executable (parameter layouts, result
+ // layout, profiling enabled, etc).
+ std::unique_ptr<HloModuleConfig> module_config_;
+
+ // SessionModule this was compiled from. Null if not dumping executions.
+ std::unique_ptr<SessionModule> session_module_;
+
+ // Execution count, used to generate a unique filename for each dumped
+ // execution.
+ int64 execution_count_ = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
new file mode 100644
index 0000000000..cf1870580c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -0,0 +1,95 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/execution_tracker.h"
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+AsyncExecution::AsyncExecution(
+ Backend* backend,
+ std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
+ const ExecutionProfile& profile, GlobalDataHandle result)
+ : backend_(CHECK_NOTNULL(backend)),
+ streams_(std::move(streams)),
+ profile_(profile),
+ result_(result) {
+ for (const auto& stream : streams_) {
+ CHECK(stream != nullptr);
+ }
+}
+
+AsyncExecution::~AsyncExecution() {
+ for (auto& stream : streams_) {
+ backend_->ReleaseStream(std::move(stream));
+ }
+}
+
+tensorflow::Status AsyncExecution::BlockUntilDone() const {
+ for (auto& stream : streams_) {
+ if (!stream->BlockHostUntilDone()) {
+ return InternalError("failed to block until done");
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
+ExecutionTracker::ExecutionTracker() : next_handle_(1) {}
+
+ExecutionHandle ExecutionTracker::Register(
+ Backend* backend,
+ std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
+ const ExecutionProfile& profile, GlobalDataHandle result) {
+ tensorflow::mutex_lock lock(execution_mutex_);
+ int64 handle = next_handle_++;
+ auto inserted = handle_to_execution_.emplace(
+ handle,
+ MakeUnique<AsyncExecution>(backend, std::move(streams), profile, result));
+ CHECK(inserted.second);
+
+ ExecutionHandle execution_handle;
+ execution_handle.set_handle(handle);
+ return execution_handle;
+}
+
+tensorflow::Status ExecutionTracker::Unregister(const ExecutionHandle& handle) {
+ tensorflow::mutex_lock lock(execution_mutex_);
+ auto it = handle_to_execution_.find(handle.handle());
+ if (it == handle_to_execution_.end()) {
+ return NotFound("no execution record for execution handle: %lld",
+ handle.handle());
+ }
+ handle_to_execution_.erase(handle.handle());
+ return tensorflow::Status::OK();
+}
+
+StatusOr<const AsyncExecution*> ExecutionTracker::Resolve(
+ const ExecutionHandle& handle) {
+ tensorflow::mutex_lock lock(execution_mutex_);
+ auto it = handle_to_execution_.find(handle.handle());
+ if (it == handle_to_execution_.end()) {
+ return NotFound("no execution record for execution handle: %lld",
+ handle.handle());
+ }
+ return it->second.get();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h
new file mode 100644
index 0000000000..99a5bb5ad9
--- /dev/null
+++ b/tensorflow/compiler/xla/service/execution_tracker.h
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
+
+#include <map>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Represents an asynchronously launched execution. Owns the streams (taken
+// from the run_options->stream() of each launch) on which the execution was
+// launched, and releases them back to the backend when destructed.
+class AsyncExecution {
+ public:
+ AsyncExecution(
+ Backend* backend,
+ std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
+ const ExecutionProfile& profile, GlobalDataHandle result);
+
+ ~AsyncExecution();
+ tensorflow::Status BlockUntilDone() const;
+
+ const GlobalDataHandle& result() const { return result_; }
+
+ const ExecutionProfile& profile() const { return profile_; }
+
+ private:
+ // Backend to execute the computation on.
+ Backend* backend_;
+
+ // Stream on which the execution is launched.
+ std::vector<std::unique_ptr<perftools::gputools::Stream>> streams_;
+
+ // Profile object of the execution to be returned to the user.
+ ExecutionProfile profile_;
+
+ // Data handle to the result of the execution. Data represented by this handle
+ // is valid only after BlockUntilDone() is called.
+ GlobalDataHandle result_;
+};
+
+// Tracks asynchronously launched executions for the XLA service.
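+//
+// Illustrative sketch only (variables are hypothetical): the service would
+// typically register a launch, later resolve and wait on it, then unregister:
+//
+//   ExecutionHandle handle =
+//       tracker.Register(backend, std::move(streams), profile, result);
+//   ...
+//   TF_ASSIGN_OR_RETURN(const AsyncExecution* execution,
+//                       tracker.Resolve(handle));
+//   TF_RETURN_IF_ERROR(execution->BlockUntilDone());
+//   TF_RETURN_IF_ERROR(tracker.Unregister(handle));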
+class ExecutionTracker {
+ public:
+ ExecutionTracker();
+
+ // Registers an execution with its backend, streams, and data handle to the
+ // execution result. Returns a handle for the registered execution.
+ ExecutionHandle Register(
+ Backend* backend,
+      std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
+ const ExecutionProfile& profile, GlobalDataHandle data);
+
+ // Unregisters the execution for the given handle.
+ tensorflow::Status Unregister(const ExecutionHandle& handle);
+
+ // Resolves the given ExecutionHandle to an AsyncExecution. Returns an
+ // error status if the given handle is not found, which means that the
+ // execution is not yet registered or already unregistered.
+ StatusOr<const AsyncExecution*> Resolve(const ExecutionHandle& handle);
+
+ private:
+ // The next handle to assign to an execution.
+ int64 next_handle_ GUARDED_BY(execution_mutex_);
+
+ // Mapping from ExecutionHandle handle to the corresponding registered
+ // AsyncExecution object.
+ std::map<int64, std::unique_ptr<AsyncExecution>> handle_to_execution_
+ GUARDED_BY(execution_mutex_);
+
+ tensorflow::mutex execution_mutex_; // Guards the execution mapping.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExecutionTracker);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTION_TRACKER_H_
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
new file mode 100644
index 0000000000..086306696d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -0,0 +1,183 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id)
+ : platform_id_(platform_id) {
+ // We currently only support kHostPlatformId for CPU and kCudaPlatformId for
+ // GPU. Before supporting other platforms, we need to test this transfer
+ // manager on them.
+ CHECK(platform_id_ == se::host::kHostPlatformId ||
+ platform_id_ == se::cuda::kCudaPlatformId);
+}
+
+se::Platform::Id GenericTransferManager::PlatformId() const {
+ if (platform_id_ == se::cuda::kCudaPlatformId ||
+ platform_id_ == se::host::kHostPlatformId) {
+ return platform_id_;
+ }
+ CHECK(false) << "GenericTransferManager::platform_id_ is invalid";
+}
+
+Status GenericTransferManager::TransferLiteralFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape, Literal* literal) {
+ VLOG(2) << "transferring literal shape from device: "
+ << ShapeUtil::HumanString(literal_shape)
+ << "; device location: " << source.opaque();
+ TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape));
+
+  // Tuples are a special case: they contain one or more shapes, nested to an
+  // arbitrary depth.
+ if (device_shape.element_type() == TUPLE) {
+ *literal->mutable_shape() = literal_shape;
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ ShallowCopyTupleFromDevice(executor, source, device_shape));
+ TF_RET_CHECK(element_buffers.size() ==
+ ShapeUtil::TupleElementCount(device_shape));
+ for (int64 i = 0; i < element_buffers.size(); ++i) {
+ const Shape& element_device_shape = device_shape.tuple_shapes(i);
+ const Shape& element_literal_shape = literal_shape.tuple_shapes(i);
+ Literal* element_literal = literal->add_tuple_literals();
+ // Recursively call TransferFromDevice to copy over the data in the
+ // element array.
+ TF_RETURN_IF_ERROR(TransferLiteralFromDevice(
+ executor, element_buffers[i], /*device_shape=*/element_device_shape,
+ /*literal_shape=*/element_literal_shape, element_literal));
+ }
+ return Status::OK();
+ }
+
+ *literal->mutable_shape() = device_shape;
+ LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal);
+ TF_RETURN_IF_ERROR(TransferBufferFromDevice(
+ executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape),
+ /*destination=*/LiteralUtil::MutableInternalData(literal)));
+ if (!ShapeUtil::Equal(literal_shape, device_shape)) {
+ literal->Swap(
+ LiteralUtil::Relayout(*literal, literal_shape.layout()).get());
+ }
+ TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape()));
+ return Status::OK();
+}
+
+StatusOr<std::vector<se::DeviceMemoryBase>>
+GenericTransferManager::ShallowCopyTupleFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+ // For devices which use the GenericTransferManager, a tuple is stored as an
+ // array of pointers to buffers. Copy the contents of the tuple buffer into
+ // a vector of void* pointers.
+ std::vector<void*> element_pointers(ShapeUtil::TupleElementCount(shape),
+ nullptr);
+ int64 tuple_size = ShapeUtil::ByteSizeOf(shape);
+ auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size,
+ element_pointers.data());
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape));
+ }
+
+ // Create a DeviceMemoryBase from each void* pointer.
+ std::vector<se::DeviceMemoryBase> destination;
+ for (int i = 0; i < element_pointers.size(); ++i) {
+ if (element_pointers[i] == nullptr &&
+ !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
+ return FailedPrecondition("tuple contains nullptr at element %d", i);
+ }
+ int64 buffer_size = ShapeUtil::ByteSizeOf(shape.tuple_shapes(i));
+ destination.emplace_back(element_pointers[i], buffer_size);
+ }
+ return std::move(destination);
+}
+
+Status GenericTransferManager::TransferLiteralToDevice(
+ se::StreamExecutor* executor, const Literal& literal,
+ se::DeviceMemoryBase* destination) {
+ const Shape& shape = literal.shape();
+ VLOG(2) << "transferring literal shape to device: "
+ << ShapeUtil::HumanString(shape)
+ << "; device location: " << destination->opaque();
+
+ if (ShapeUtil::IsTuple(literal.shape())) {
+ std::vector<void*> tuple_elements_on_device;
+ for (const Literal& tuple_element : literal.tuple_literals()) {
+ se::DeviceMemoryBase allocation = executor->AllocateArray<uint8>(
+ GetByteSizeRequirement(tuple_element.shape()));
+ TF_RETURN_IF_ERROR(
+ TransferLiteralToDevice(executor, tuple_element, &allocation));
+ tuple_elements_on_device.push_back(allocation.opaque());
+ }
+ return TransferBufferToDevice(
+ executor, tuple_elements_on_device.size() * sizeof(void*),
+ tuple_elements_on_device.data(), destination);
+ }
+
+ return TransferBufferToDevice(
+ executor, /*size=*/GetByteSizeRequirement(shape),
+ /*source=*/LiteralUtil::InternalData(literal), destination);
+}
+
+Status GenericTransferManager::TransferLiteralToInfeed(
+ se::StreamExecutor* executor, const Literal& literal) {
+ return Unimplemented("Infeed is not supported on GPU (b/30467474)");
+}
+
+Status GenericTransferManager::ResetDevice(se::StreamExecutor* executor) {
+ return Unimplemented(
+ "Device reset is not yet supported on CPU and GPU (b/30481585)");
+}
+
+int64 GenericTransferManager::GetByteSizeRequirement(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape);
+}
+
+} // namespace xla
+
+static xla::TransferManager* CreateGenericTransferManager() {
+ return new xla::GenericTransferManager(se::cuda::kCudaPlatformId);
+}
+
+static bool InitModule() {
+ xla::TransferManager::RegisterTransferManager(se::cuda::kCudaPlatformId,
+ CreateGenericTransferManager);
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
new file mode 100644
index 0000000000..cfa02bf22f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A generic implementation of the XLA TransferManager interface
+// that is the base class for both CPU and GPU. For GPU, it transfers
+// data between host and device (GPU). For CPU, since the "device"
+// here is the host itself, there's not much for this transfer manager
+// to do except memcpy the result. There is a CpuTransferManager that
+// inherits from GenericTransferManager and handles CPU-specific
+// infeed.
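+//
+// Illustrative sketch only (`platform_id`, `executor`, `literal`, `shape` and
+// a suitably sized `device_buffer` are assumed to be provided by the caller):
+//
+//   GenericTransferManager manager(platform_id);
+//   TF_RETURN_IF_ERROR(
+//       manager.TransferLiteralToDevice(executor, literal, &device_buffer));
+//   Literal round_tripped;
+//   TF_RETURN_IF_ERROR(manager.TransferLiteralFromDevice(
+//       executor, device_buffer, shape, shape, &round_tripped));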
+class GenericTransferManager : public TransferManager {
+ public:
+ explicit GenericTransferManager(
+ perftools::gputools::Platform::Id platform_id);
+ ~GenericTransferManager() override {}
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ Status TransferLiteralFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source,
+ const Shape& device_shape, const Shape& literal_shape,
+ Literal* literal) override;
+
+ Status TransferLiteralToDevice(
+ perftools::gputools::StreamExecutor* executor, const Literal& literal,
+ perftools::gputools::DeviceMemoryBase* destination) override;
+
+ Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor,
+ const Literal& literal) override;
+
+ Status ResetDevice(perftools::gputools::StreamExecutor* executor) override;
+
+ StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
+ ShallowCopyTupleFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source,
+ const Shape& shape) override;
+
+ int64 GetByteSizeRequirement(const Shape& shape) override;
+
+ private:
+ // The platform this transfer manager targets.
+ perftools::gputools::Platform::Id platform_id_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GENERIC_TRANSFER_MANAGER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
new file mode 100644
index 0000000000..9aeebe42f8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -0,0 +1,533 @@
+# Description:
+# GPU-specific components in XLA service implementation.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [":friends"])
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "partition_assignment",
+ srcs = [
+ "partition_assignment.cc",
+ ],
+ hdrs = [
+ "partition_assignment.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+# TODO(b/29140563) This target is flaky, disabled until flakiness is
+# root-caused. Failed on 2016-06-08.
+#cc_test(
+# name = "partition_assignment_test",
+# srcs = [
+# "partition_assignment_test.cc",
+# ],
+# tags = [
+# "requires-gpu-sm35",
+# ],
+# deps = [
+# ":partition_assignment",
+# "//tensorflow/core:stream_executor_no_cuda",
+# "//tensorflow/compiler/xla:shape_util",
+# "//tensorflow/compiler/xla:xla_data_proto",
+# "//tensorflow/compiler/xla/service:gpu_plugin",
+# "//tensorflow/compiler/xla/service:hlo",
+# "//tensorflow/compiler/xla/tests:hlo_test_base",
+# "//tensorflow/core:test_main",
+# ],
+#)
+
+cc_library(
+ name = "stream_assignment",
+ srcs = ["stream_assignment.cc"],
+ hdrs = ["stream_assignment.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:stream_assignment_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "stream_assignment_test",
+ srcs = [
+ "stream_assignment_test.cc",
+ ],
+ deps = [
+ ":stream_assignment",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_to_ir_bindings",
+ srcs = ["hlo_to_ir_bindings.cc"],
+ hdrs = ["hlo_to_ir_bindings.h"],
+ deps = [
+ ":ir_emission_utils",
+ ":temp_buffer_offsets",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:ops",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "ir_emitter",
+ srcs = [
+ "ir_emitter.cc",
+ "ir_emitter_nested.cc",
+ "ir_emitter_unnested.cc",
+ ],
+ hdrs = [
+ "ir_emitter.h",
+ "ir_emitter_context.h",
+ ],
+ deps = [
+ ":elemental_ir_emitter",
+ ":gpu_executable",
+ ":hlo_to_ir_bindings",
+ ":ir_emission_utils",
+ ":parallel_loop_emitter",
+ ":partition_assignment",
+ ":temp_buffer_offsets",
+ ":while_transformer",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:name_uniquer",
+ "//tensorflow/compiler/xla/service/llvm_ir:fused_ir_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/compiler/xla/service/llvm_ir:ops",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:core",
+ "@llvm//:support",
+ ],
+)
+
+cc_library(
+ name = "parallel_loop_emitter",
+ srcs = ["parallel_loop_emitter.cc"],
+ hdrs = ["parallel_loop_emitter.h"],
+ deps = [
+ ":partition_assignment",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "elemental_ir_emitter",
+ srcs = ["elemental_ir_emitter.cc"],
+ hdrs = ["elemental_ir_emitter.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ "@llvm//:support",
+ ],
+)
+
+cc_library(
+ name = "temp_buffer_offsets",
+ srcs = ["temp_buffer_offsets.cc"],
+ hdrs = ["temp_buffer_offsets.h"],
+ deps = [
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "buffer_allocations",
+ srcs = ["buffer_allocations.cc"],
+ hdrs = ["buffer_allocations.h"],
+ deps = [
+ ":temp_buffer_offsets",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+cc_library(
+ name = "gpu_executable",
+ srcs = [
+ "convolution_thunk.cc",
+ "copy_thunk.cc",
+ "for_thunk.cc",
+ "gemm_thunk.cc",
+ "gpu_executable.cc",
+ "kernel_thunk.cc",
+ "sequential_thunk.cc",
+ "thunk_schedule.cc",
+ "tuple_thunk.cc",
+ "while_thunk.cc",
+ ],
+ hdrs = [
+ "convolution_thunk.h",
+ "copy_thunk.h",
+ "for_thunk.h",
+ "gemm_thunk.h",
+ "gpu_executable.h",
+ "kernel_thunk.h",
+ "sequential_thunk.h",
+ "thunk.h",
+ "thunk_schedule.h",
+ "tuple_thunk.h",
+ "while_thunk.h",
+ ],
+ deps = [
+ ":buffer_allocations",
+ ":partition_assignment",
+ ":stream_assignment",
+ ":temp_buffer_offsets",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:convolution_thunk_flags",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core/platform/default/build_config:stream_executor_cuda",
+ ],
+)
+
+cc_library(
+ name = "ir_emission_utils",
+ srcs = ["ir_emission_utils.cc"],
+ hdrs = ["ir_emission_utils.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "convolution_folding",
+ srcs = ["convolution_folding.cc"],
+ hdrs = ["convolution_folding.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "convolution_folding_test",
+ srcs = ["convolution_folding_test.cc"],
+ deps = [
+ ":convolution_folding",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "instruction_fusion",
+ srcs = ["instruction_fusion.cc"],
+ hdrs = ["instruction_fusion.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:instruction_fusion",
+ ],
+)
+
+cc_test(
+ name = "instruction_fusion_test",
+ srcs = ["instruction_fusion_test.cc"],
+ deps = [
+ ":instruction_fusion",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "copy_insertion",
+ srcs = ["copy_insertion.cc"],
+ hdrs = ["copy_insertion.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla/service:copy_insertion",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/compiler/xla/service:tuple_points_to_analysis",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "pad_insertion",
+ srcs = ["pad_insertion.cc"],
+ hdrs = ["pad_insertion.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:window_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:shape_inference",
+ ],
+)
+
+cc_library(
+ name = "gpu_compiler",
+ srcs = ["gpu_compiler.cc"],
+ hdrs = ["gpu_compiler.h"],
+ deps = [
+ ":convolution_folding",
+ ":copy_insertion",
+ ":gpu_executable",
+ ":hlo_schedule",
+ ":instruction_fusion",
+ ":ir_emission_utils",
+ ":ir_emitter",
+ ":layout_assignment",
+ ":pad_insertion",
+ ":partition_assignment",
+ ":stream_assignment",
+ ":temp_buffer_offsets",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:gpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:algebraic_simplifier",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:buffer_liveness",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_cse",
+ "//tensorflow/compiler/xla/service:hlo_dce",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:hlo_pass",
+ "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+ "//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
+ "//tensorflow/compiler/xla/service:reshape_mover",
+ "//tensorflow/compiler/xla/service:transpose_folding",
+ "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:cuda_libdevice_path",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "@llvm//:core",
+ "@llvm//:support",
+ ],
+ alwayslink = True, # Contains compiler registration
+)
+
+cc_library(
+ name = "layout_assignment",
+ srcs = ["layout_assignment.cc"],
+ hdrs = ["layout_assignment.h"],
+ deps = [
+ ":ir_emission_utils",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:layout_assignment",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "layout_assignment_test",
+ srcs = ["layout_assignment_test.cc"],
+ deps = [
+ ":layout_assignment",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "hlo_schedule",
+ srcs = ["hlo_schedule.cc"],
+ hdrs = ["hlo_schedule.h"],
+ deps = [
+ ":stream_assignment",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:buffer_liveness",
+ "//tensorflow/compiler/xla/service:hlo",
+ ],
+)
+
+cc_test(
+ name = "hlo_schedule_test",
+ srcs = [
+ "hlo_schedule_test.cc",
+ ],
+ deps = [
+ ":hlo_schedule",
+ ":stream_assignment",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "while_transformer",
+ srcs = ["while_transformer.cc"],
+ hdrs = ["while_transformer.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "while_transformer_test",
+ srcs = ["while_transformer_test.cc"],
+ deps = [
+ ":while_transformer",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
new file mode 100644
index 0000000000..a9975de3f1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc
@@ -0,0 +1,139 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+void BufferAllocations::Builder::RegisterBuffer(BufferAllocation::Index index,
+ se::DeviceMemoryBase address) {
+ InsertOrDie(&registered_buffers_, index, address);
+}
+
+StatusOr<std::unique_ptr<BufferAllocations>> BufferAllocations::Builder::Build(
+ const BufferAssignment& buffer_assignment,
+ const TempBufferOffsets& temp_buffer_offsets, int device_ordinal,
+ DeviceMemoryAllocator* memory_allocator) {
+ se::DeviceMemoryBase temp_buffer_base;
+ if (temp_buffer_offsets.TotalSizeInBytes() > 0) {
+ TF_ASSIGN_OR_RETURN(
+ temp_buffer_base,
+ memory_allocator->Allocate(device_ordinal,
+ temp_buffer_offsets.TotalSizeInBytes()));
+ if (temp_buffer_base == nullptr) {
+ return ResourceExhausted(
+ "Out of memory when allocating %s bytes for temporary buffers.",
+ tensorflow::strings::HumanReadableNumBytes(
+ temp_buffer_offsets.TotalSizeInBytes())
+ .c_str());
+ }
+ }
+ auto buffer_allocations = WrapUnique(new BufferAllocations(
+ buffer_assignment.Allocations().size(), temp_buffer_base, device_ordinal,
+ memory_allocator));
+
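+  // Illustrative summary of the loop below (added commentary, not original):
+  // reuse addresses registered with the builder, allocate separately for
+  // buffers that may live out, and carve preallocated temporary buffers out
+  // of the shared temp block.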
+ int64 num_buffers = buffer_assignment.Allocations().size();
+ for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
+ // If buffer #i's address is already registered (e.g. external arguments or
+ // result buffers), use that registered buffer.
+ if (registered_buffers_.count(i)) {
+ buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i));
+ continue;
+ }
+
+ const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+ if (allocation.maybe_live_out()) {
+ auto buffer_size = allocation.size();
+ se::DeviceMemoryBase buffer_address;
+ if (buffer_size > 0) {
+ // If the buffer escapes, we need to allocate it separately instead of
+ // merging it into the memory block for temporary buffers.
+ TF_ASSIGN_OR_RETURN(buffer_address, memory_allocator->Allocate(
+ device_ordinal, buffer_size));
+ if (buffer_address == nullptr) {
+ return ResourceExhausted(
+ "Out of memory when allocating %s for buffer %lld.",
+ tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(),
+ i);
+ }
+ }
+ buffer_allocations->SetBuffer(i, buffer_address);
+ } else if (allocation.IsPreallocatedTempBuffer()) {
+ se::DeviceMemoryBase temp_buffer_address(
+ /*opaque=*/static_cast<char*>(
+ buffer_allocations->GetTempBufferBase().opaque()) +
+ temp_buffer_offsets.GetOffset(i),
+ /*size=*/allocation.size());
+ buffer_allocations->SetBuffer(i, temp_buffer_address);
+ }
+ }
+
+ return std::move(buffer_allocations);
+}
+
+tensorflow::Status BufferAllocations::TearDown(
+ const std::set<se::DeviceMemoryBase>& live_addresses,
+ const BufferAssignment& buffer_assignment) {
+  // Deallocate separately allocated buffers.
+ for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) {
+ const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+ se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
+ if (allocation.maybe_live_out() && !live_addresses.count(buffer_address)) {
+      // Deallocate buffers that are marked "maybe_live_out" but are not
+      // actually live out.
+ TF_RETURN_IF_ERROR(
+ memory_allocator_->Deallocate(device_ordinal_, &buffer_address));
+ }
+ }
+
+ // Deallocate the memory block for temporary buffers.
+ if (temp_buffer_base_ != nullptr) {
+ TF_RETURN_IF_ERROR(
+ memory_allocator_->Deallocate(device_ordinal_, &temp_buffer_base_));
+ }
+ return tensorflow::Status::OK();
+}
+
+se::DeviceMemoryBase BufferAllocations::GetDeviceAddress(
+ BufferAllocation::Index buffer_index) const {
+ CHECK_GE(buffer_index, 0);
+ CHECK_LT(buffer_index, buffers_.size());
+ return buffers_[buffer_index];
+}
+
+void BufferAllocations::SetBuffer(BufferAllocation::Index buffer_index,
+ se::DeviceMemoryBase buffer) {
+ CHECK_GE(buffer_index, 0);
+ CHECK_LT(buffer_index, buffers_.size());
+ buffers_[buffer_index] = buffer;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.h b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
new file mode 100644
index 0000000000..a0cd6cac01
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.h
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_
+
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A thread-compatible class that encapsulates the base addresses of the
+// allocated device buffers.
+class BufferAllocations {
+ public:
+ // This inner class encapsulates methods that build a BufferAllocations from
+ // the given buffer assignment.
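+  //
+  // A minimal usage sketch (illustrative only; `assignment`, `offsets`,
+  // `allocator`, and `result_address` stand for a BufferAssignment, a
+  // TempBufferOffsets, a DeviceMemoryAllocator*, and a device address
+  // supplied by the caller):
+  //
+  //   BufferAllocations::Builder builder;
+  //   builder.RegisterBuffer(/*index=*/0, result_address);
+  //   TF_ASSIGN_OR_RETURN(
+  //       std::unique_ptr<BufferAllocations> allocations,
+  //       builder.Build(assignment, offsets, /*device_ordinal=*/0, allocator));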
+ class Builder {
+ public:
+    // Registers a preallocated buffer (such as a parameter address or a
+    // user-specified result buffer) under the given buffer index. The builder
+    // will skip allocating buffers for registered buffer indices.
+ void RegisterBuffer(BufferAllocation::Index index,
+ perftools::gputools::DeviceMemoryBase address);
+
+ // Builds a BufferAllocations object from the given buffer assignment.
+    // `memory_allocator` is what this function uses to allocate device memory.
+    // `device_ordinal` is the ordinal of the device on which this function
+    // allocates memory.
+ StatusOr<std::unique_ptr<BufferAllocations>> Build(
+ const BufferAssignment& buffer_assignment,
+ const TempBufferOffsets& temp_buffer_offsets, int device_ordinal,
+ DeviceMemoryAllocator* memory_allocator);
+
+ private:
+ std::map<BufferAllocation::Index, perftools::gputools::DeviceMemoryBase>
+ registered_buffers_;
+ };
+
+ BufferAllocations(const BufferAllocations&) = delete;
+ BufferAllocations& operator=(const BufferAllocations&) = delete;
+
+ DeviceMemoryAllocator* memory_allocator() const { return memory_allocator_; }
+ int device_ordinal() const { return device_ordinal_; }
+
+ // Returns the device address of buffer `buffer_index`. `buffer_index` must be
+ // a valid index, i.e., in [0, buffer_count). This function returns null if
+ // `buffer_index` is not assigned to a buffer address.
+ perftools::gputools::DeviceMemoryBase GetDeviceAddress(
+ BufferAllocation::Index buffer_index) const;
+
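+  // Returns the base address of the memory block that backs all preallocated
+  // temporary buffers.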
+ perftools::gputools::DeviceMemoryBase GetTempBufferBase() const {
+ return temp_buffer_base_;
+ }
+
+ // Tears down all buffers allocated by this object that are not in
+ // `live_addresses`.
+ tensorflow::Status TearDown(
+ const std::set<perftools::gputools::DeviceMemoryBase>& live_addresses,
+ const BufferAssignment& buffer_assignment);
+
+ private:
+ BufferAllocations(BufferAllocation::Index buffer_count,
+ perftools::gputools::DeviceMemoryBase temp_buffer_base,
+ int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+ : buffers_(buffer_count),
+ temp_buffer_base_(
+ perftools::gputools::DeviceMemory<void*>(temp_buffer_base)),
+ device_ordinal_(device_ordinal),
+ memory_allocator_(memory_allocator) {}
+
+ // Sets the device address of buffer `buffer_index`.
+ void SetBuffer(BufferAllocation::Index buffer_index,
+ perftools::gputools::DeviceMemoryBase buffer);
+
+  // An array of device pointers that stores the address of each buffer,
+  // indexed by Index. Each element points to a temporary buffer, to an input
+  // buffer, or is null if no buffer is needed for that Index.
+ std::vector<perftools::gputools::DeviceMemoryBase> buffers_;
+
+ // The base address of the memory block that contains all temporary buffers.
+ perftools::gputools::DeviceMemory<void*> temp_buffer_base_;
+
+ int device_ordinal_;
+
+ DeviceMemoryAllocator* memory_allocator_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_BUFFER_ALLOCATIONS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
new file mode 100644
index 0000000000..dd1b09c6cc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
@@ -0,0 +1,443 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+// Try to match a backward filter pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
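+// Returns a tuple of (matched, HLOs to fuse, backward convolution window,
+// backward convolution dimension numbers); `matched` is false when the
+// pattern cannot be folded.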
+std::tuple<bool, std::vector<HloInstruction*>, Window,
+ ConvolutionDimensionNumbers>
+MatchBackwardFilter(HloInstruction* conv) {
+ const auto no_match_result =
+ std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
+ ConvolutionDimensionNumbers());
+ // Step 1: match the instruction pattern without considering the paddings and
+ // dimension numbers just yet. We may need some generic pattern matcher
+ // similar to external/llvm/include/llvm/IR/PatternMatch.h
+ //
+ // Backward filter convolution is implemented in XLA as the forward
+ // convolution of padded activations and dilated gradients. Padding on
+ // activations and dilation on gradients are specified in the "window" field
+ // of the forward convolution.
+ //
+ // activations gradients
+ // \ /
+ // v v
+ // Convolution
+ // conv
+ // |
+ // v
+ // Transpose (optional if identity transposition)
+ CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+ // If the forward convolution is followed by a transpose, we can fuse the
+ // transpose into the backward convolution as well.
+ HloInstruction* transpose = nullptr;
+ if (conv->user_count() == 1) {
+ HloInstruction* single_user = *conv->users().begin();
+ if (single_user->opcode() == HloOpcode::kTranspose) {
+ transpose = single_user;
+ }
+ }
+
+ // Step 2: match paddings and dimension numbers of the forward convolution.
+ const ConvolutionDimensionNumbers& conv_dnums =
+ conv->convolution_dimension_numbers();
+ auto batch_dim = conv_dnums.batch_dimension();
+ auto feature_dim = conv_dnums.feature_dimension();
+ auto spatial_dims = conv_dnums.spatial_dimensions();
+
+ for (const WindowDimension& window_dim : conv->window().dimensions()) {
+ if (window_dim.stride() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have stride of 1.";
+ return no_match_result;
+ }
+ if (window_dim.base_dilation() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have no base (LHS) dilation.";
+ return no_match_result;
+ }
+ if (window_dim.padding_low() < 0) {
+ VLOG(1) << "Padding low should be non-negative.";
+ return no_match_result;
+ }
+ // Padding high will be checked in Step 3.
+ }
+ if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) {
+ VLOG(1) << conv->ToString()
+ << " is a regular forward convolution. No need "
+ "to fold it to a backward filter convolution.";
+ return no_match_result;
+ }
+
+ // Step 3: fuse the matched HLOs into a backward convolution instruction.
+ //
+ // Compute the window of the backward convolution.
+ Window backward_conv_window;
+ for (int i = 0; i < 2; ++i) {
+ WindowDimension* dim = backward_conv_window.add_dimensions();
+ // The window size of the backward convolution equals the output size of the
+ // forward convolution.
+ int64 filter_size = conv->shape().dimensions(spatial_dims[i]);
+ dim->set_size(filter_size);
+ // The window stride equals the window dilation of the forward convolution.
+ dim->set_stride(conv->window().dimensions(i).window_dilation());
+ // The window's low padding is the same as the low padding of the
+ // activations.
+ dim->set_padding_low(conv->window().dimensions(i).padding_low());
+
+ int64 input_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+ int64 output_size = conv->window().dimensions(i).size();
+ // Compute the range of the amount of valid high padding. We first compute
+ // min_padding_high, the amount of padding on the right/bottom to ensure the
+ // last patch ends at the border, i.e.,
+ //
+ // input_size + dim->padding_low() + min_padding_high
+ // = (output_size - 1) * stride + filter_size
+ //
+ // Because convolution ignores trailing incomplete windows, any amount of
+ // padding high from min_padding_high to min_padding_high+stride-1
+ // (max_padding_high) has the same effect.
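+    // As a purely illustrative example (numbers not taken from the original):
+    // with filter_size=3, output_size=4, stride=2, input_size=8, and
+    // padding_low=0, padded_input_size = 3 + (4-1)*2 = 9, so
+    // min_padding_high = 9 - 8 - 0 = 1 and max_padding_high = 1 + 2 - 1 = 2;
+    // any high padding in [1, 2] yields the same backward filter result.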
+ int64 padded_input_size = filter_size + (output_size - 1) * dim->stride();
+ int64 min_padding_high =
+ padded_input_size - input_size - dim->padding_low();
+ int64 max_padding_high = min_padding_high + dim->stride() - 1;
+ CHECK_GE(dim->padding_low(), 0);
+ // In practice, since cuDNN convolution only supports even padding, we make
+ // the amount of high padding the same as the amount of low padding as long
+ // as it is between min_padding_high and max_padding_high. If it is not in
+ // that range, we pick the one that's closest to dim->padding_low() and let
+ // PadInsertion canonicalize the resultant backward convolution later.
+ // Picking the closest one minimizes the cost of the kPad instruction to be
+ // inserted by PadInsertion.
+ if (dim->padding_low() >= min_padding_high &&
+ dim->padding_low() <= max_padding_high) {
+ dim->set_padding_high(dim->padding_low());
+ } else {
+ if (dim->padding_low() < min_padding_high) {
+ dim->set_padding_high(min_padding_high);
+ } else {
+ dim->set_padding_high(max_padding_high);
+ }
+ }
+ if (dim->padding_high() < 0) {
+ LOG(ERROR)
+ << "Fusing this pattern to backward filter convolution would cause "
+ "negative padding ("
+ << dim->padding_high()
+ << ") on right/bottom of the weight gradients, which is not "
+ "supported by PadInsertion (b/32744257). Falling back to "
+ "unfused convolution for instruction: "
+ << conv->ToString();
+ return no_match_result;
+ }
+ }
+
+ // To make future HLO passes easier, we canonicalize the fused expression by
+ // adding an identity transposition if it's omitted in the pattern.
+ if (transpose == nullptr) {
+ // Create an identity transposition with the same rank as the forward
+ // convolution.
+ HloComputation* parent_computation = conv->parent();
+ std::vector<int64> transpose_dimensions(ShapeUtil::Rank(conv->shape()));
+ std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0);
+ transpose =
+ parent_computation->AddInstruction(HloInstruction::CreateTranspose(
+ conv->shape(), conv, transpose_dimensions));
+ parent_computation->ReplaceUsesOfInstruction(conv, transpose);
+ }
+
+ // Restore the dimension numbers of the backward convolution from the forward
+ // convolution. The two activation dimensions are reversed (batch and
+ // feature).
+ ConvolutionDimensionNumbers backward_conv_dnums;
+ backward_conv_dnums.set_batch_dimension(feature_dim);
+ backward_conv_dnums.set_feature_dimension(batch_dim);
+ for (int i = 0; i < 2; ++i) {
+ backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]);
+ }
+ // The dimension numbering of the output of the forward convolution (before
+ // transposition) is the same as that of the activations (according to the
+ // semantics of kConvolution). The batch dimension of the activations should
+  // be treated as the input feature dimension, and the feature dimension
+  // should be treated as the output feature dimension.
+ //
+ // The output of the forward convolution needs to be transposed to fit into
+ // the dimension numbering of the weight gradients. This transposition maps
+ // dimension i to PositionInContainer(transpose->dimensions(), i).
+ backward_conv_dnums.set_kernel_input_feature_dimension(
+ PositionInContainer(transpose->dimensions(), batch_dim));
+ backward_conv_dnums.set_kernel_output_feature_dimension(
+ PositionInContainer(transpose->dimensions(), feature_dim));
+ for (int i = 0; i < 2; ++i) {
+ backward_conv_dnums.add_kernel_spatial_dimensions(
+ PositionInContainer(transpose->dimensions(), spatial_dims[i]));
+ }
+
+ return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}),
+ backward_conv_window, backward_conv_dnums);
+}
+
+// Try to match a backward input pattern that contains "conv".
+// Precondition: "conv" is a kConvolution.
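+// Returns the same tuple shape as MatchBackwardFilter above: (matched, HLOs
+// to fuse, backward convolution window, backward convolution dimension
+// numbers).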
+std::tuple<bool, std::vector<HloInstruction*>, Window,
+ ConvolutionDimensionNumbers>
+MatchBackwardInput(HloInstruction* conv) {
+ const auto no_match_result =
+ std::make_tuple(false, std::vector<HloInstruction*>(), Window(),
+ ConvolutionDimensionNumbers());
+
+ // Match instruction pattern.
+ CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
+ HloInstruction* reverse_filter = conv->mutable_operand(1);
+
+ // Match the reverse of the filter.
+ ConvolutionDimensionNumbers dnums = conv->convolution_dimension_numbers();
+ const auto& kernel_spatial_dims = dnums.kernel_spatial_dimensions();
+ if (reverse_filter->opcode() == HloOpcode::kReverse) {
+ if (kernel_spatial_dims.size() != reverse_filter->dimensions().size() ||
+ !std::is_permutation(kernel_spatial_dims.begin(),
+ kernel_spatial_dims.end(),
+ reverse_filter->dimensions().begin())) {
+ VLOG(1)
+ << "Backward input convolution should reverse all kernel dimensions.";
+ return no_match_result;
+ }
+ } else {
+ // Possibly 1x1 filter.
+ for (int64 i = 0; i < kernel_spatial_dims.size(); ++i) {
+ if (conv->window().dimensions(i).size() != 1) {
+ VLOG(1) << "The reverse filter is neither a kReverse nor a 1x1 filter: "
+ << reverse_filter->ToString();
+ return no_match_result;
+ }
+ }
+ if (!window_util::HasBaseDilation(conv->window())) {
+ VLOG(1) << conv->ToString()
+ << " is a regular forward convolution. No need "
+ "to fold it to a backward input convolution.";
+ return no_match_result;
+ }
+ }
+
+ // Match padding and dilation of the forward convolution.
+ for (const WindowDimension& window_dim : conv->window().dimensions()) {
+ if (window_dim.stride() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have stride of 1.";
+ return no_match_result;
+ }
+ if (window_dim.window_dilation() != 1) {
+ VLOG(1) << "Forward convolution's window "
+ << conv->window().ShortDebugString()
+ << " should have no window dilation.";
+ return no_match_result;
+ }
+ }
+
+ const auto& spatial_dims = dnums.spatial_dimensions();
+ CHECK_EQ(conv->window().dimensions().size(), spatial_dims.size());
+
+ const Window& old_window = conv->window();
+ Window new_window = old_window;
+ for (size_t i = 0; i < spatial_dims.size(); ++i) {
+ // Restore backward convolution's padding config from the matched pattern.
+ // See the comment in tensorflow/core/kernels/conv_grad_ops.cc
+ // for how we convert backward input convolution to a variant of forward
+ // convolution.
+ //
+ // The stride of the backward convolution
+ // = the base dilation factor of the forward convolution
+ auto dim = new_window.mutable_dimensions(i);
+ dim->set_stride(old_window.dimensions(i).base_dilation());
+
+ // The low padding = kernel_size - 1 - low padding on the gradients
+ // Make sure the low padding is not negative.
+ auto kernel_size = old_window.dimensions(i).size();
+ auto backward_padding_low =
+ kernel_size - 1 - old_window.dimensions(i).padding_low();
+ if (backward_padding_low < 0) {
+ LOG(ERROR)
+ << "The low padding of the backward convolution would be negative ("
+ << backward_padding_low
+ << "), which isn't supported by PadInsertion for now (b/32744257).";
+ return no_match_result;
+ }
+ dim->set_padding_low(backward_padding_low);
+
+ // Compute the range of the amount of padding on the right/bottom of the
+ // activations. XLA's convolution requires all patches to be within the
+    // padded base. This gives us flexibility to choose the amount of high
+ // padding from a set of values without changing the result of the backward
+ // convolution. The minimum amount (min_padding_high) makes the last patch
+ // end at the border. The maximum amount (max_padding_high) equals
+ // min_padding_high+stride-1 -- max_padding_high+1 would cause the output
+ // size to change.
+ auto unpadded_input_size = conv->shape().dimensions(spatial_dims[i]);
+ auto output_size = conv->operand(0)->shape().dimensions(spatial_dims[i]);
+ auto padded_input_size = kernel_size + dim->stride() * (output_size - 1);
+ auto total_pad_size = padded_input_size - unpadded_input_size;
+ auto min_padding_high = total_pad_size - backward_padding_low;
+ auto max_padding_high = min_padding_high + dim->stride() - 1;
+
+ if (backward_padding_low >= min_padding_high &&
+ backward_padding_low <= max_padding_high) {
+ // In the best case (most likely), if backward_padding_low is in the range
+ // of the amounts of valid high padding, we choose backward_padding_low
+ // because cuDNN supports even padding only.
+ dim->set_padding_high(backward_padding_low);
+ } else {
+ // Otherwise, we choose the amount that's closest to backward_padding_low,
+ // and PadInsertion will later insert kSlice instructions to enforce even
+ // padding.
+ //
+ // For example, consider the backward convolution pattern
+ //
+ // ab xy
+ // | pad | reverse
+ // .a.b yx
+ // \ /
+ // ABC
+ //
+ // The amount of low padding on activations (in backward convolution) is
+ // backward_padding_low = kernel_size - 1 - forward_padding_low
+ // = 2 - 1 - 1 = 0
+ //
+ // The amount of padding high must be between 1 and 2, in order to make
+ // Conv(ABC, xy, stride=2) produce exactly 2 elements (ab). 0 is not in
+ // the range of [1,2], so we pick the closest valid amount of padding
+ // high, which is 1 in this case. Therefore, we fuse the above pattern to
+ //
+ // ABC = BackwardInputConv(ab, xy, stride=2, padding_high=1)
+ if (backward_padding_low < min_padding_high) {
+ dim->set_padding_high(min_padding_high);
+ } else {
+ dim->set_padding_high(max_padding_high);
+ }
+ }
+ // PadInsertion doesn't handle backward input convolution with negative
+    // padding for now, so we fall back to unfused convolution in case of
+    // negative padding. For example,
+ // ABCD = Conv(abc, reverse(xy), padding_high=2)
+ // could be fused to
+ // ABCD = BackwardInputConv(abc, xy, padding_low=1, padding_high=-1)
+ // with positive padding low but negative padding high.
+ if (dim->padding_high() < 0) {
+ LOG(ERROR) << "Fusing this pattern to backward convolution would cause "
+ "negative padding ("
+ << dim->padding_high()
+ << ") on right/bottom of the activations, which is not "
+ "supported by PadInsertion (b/32744257). Falling back to "
+ "unfused convolution for instruction: "
+ << conv->ToString();
+ return no_match_result;
+ }
+ }
+
+ // Fuse the matched HLOs into a backward convolution instruction.
+ //
+ // If the reverse is omitted (for 1x1 filters) in the original pattern, we add
+ // it back in the fusion instruction so that later passes (such as
+ // PadInsertion) can handle such fusion instructions easily.
+ if (reverse_filter->opcode() != HloOpcode::kReverse) {
+ reverse_filter = reverse_filter->parent()->AddInstruction(
+ HloInstruction::CreateReverse(reverse_filter->shape(), reverse_filter,
+ AsInt64Slice(kernel_spatial_dims)));
+ conv->ReplaceOperandWith(/*operand_no=*/1, reverse_filter);
+ }
+ dnums.set_kernel_input_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_output_feature_dimension());
+ dnums.set_kernel_output_feature_dimension(
+ conv->convolution_dimension_numbers().kernel_input_feature_dimension());
+
+ return std::make_tuple(true,
+ std::vector<HloInstruction*>({conv, reverse_filter}),
+ new_window, dnums);
+}
+} // namespace
+
+StatusOr<bool> ConvolutionFolding::Run(HloModule* module) {
+ HloComputation* entry_computation = module->entry_computation();
+ std::vector<HloInstruction*> convs;
+ for (const auto& hlo : entry_computation->instructions()) {
+ if (hlo->opcode() == HloOpcode::kConvolution) {
+ convs.push_back(hlo.get());
+ }
+ }
+
+ bool changed = false;
+ for (HloInstruction* conv : convs) {
+ bool match;
+ std::vector<HloInstruction*> hlos_to_fuse;
+ Window window;
+ ConvolutionDimensionNumbers dnums;
+ std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv);
+ if (match) {
+ VLOG(2) << "Fuse instructions";
+ for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+ VLOG(2) << " " << hlo_to_fuse->ToString();
+ }
+ HloInstruction* backward_convolution =
+ entry_computation->CreateFusionInstructionForBackwardConvolution(
+ hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter,
+ window, dnums);
+ VLOG(2) << "to backward filter convolution";
+ VLOG(2) << " " << backward_convolution->ToString();
+ changed = true;
+ continue;
+ }
+
+ std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv);
+ if (match) {
+ VLOG(2) << "Fuse instructions";
+ for (HloInstruction* hlo_to_fuse : hlos_to_fuse) {
+ VLOG(2) << " " << hlo_to_fuse->ToString();
+ }
+ HloInstruction* backward_convolution =
+ entry_computation->CreateFusionInstructionForBackwardConvolution(
+ hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput,
+ window, dnums);
+ VLOG(2) << "to backward input convolution";
+ VLOG(2) << " " << backward_convolution->ToString();
+ changed = true;
+ continue;
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/gpu/convolution_folding.h
new file mode 100644
index 0000000000..e0233228c7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+namespace gpu {
+
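+// An HLO pass that folds the forward-convolution patterns XLA emits for
+// backward filter and backward input convolutions into fused backward
+// convolution instructions. The matched patterns are described in
+// convolution_folding.cc.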
+class ConvolutionFolding : public HloPass {
+ public:
+ ConvolutionFolding() : HloPass("convolution-folding") {}
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
new file mode 100644
index 0000000000..83922cbe14
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
@@ -0,0 +1,552 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+class ConvolutionFoldingTest : public HloTestBase {
+ public:
+ ConvolutionFoldingTest() {
+ for (int i = 0; i < 2; ++i) {
+ WindowDimension* window_dim = default_conv_window_.add_dimensions();
+ window_dim->set_size(1);
+ window_dim->set_stride(1);
+ window_dim->set_padding_low(0);
+ window_dim->set_padding_high(0);
+ window_dim->set_window_dilation(1);
+ window_dim->set_base_dilation(1);
+ }
+    // TF data shapes are by default in NHWC order, and filter shapes are by
+    // default in HWIO order. For backward filter convolution, we need to swap
+ // the batch and feature dimension in the activations, and treat the batch
+ // dimension in gradients as the input feature dimension in the filter.
+ //
+ // TODO(jingyue): Add more tests on NCHW input order which TF also supports.
+ tf_default_dnums_for_backward_filter_.set_batch_dimension(3);
+ tf_default_dnums_for_backward_filter_.set_feature_dimension(0);
+ tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
+ tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
+ 3);
+ tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
+
+ tf_default_dnums_for_backward_input_.set_batch_dimension(0);
+ tf_default_dnums_for_backward_input_.set_feature_dimension(3);
+ tf_default_dnums_for_backward_input_.add_spatial_dimensions(1);
+ tf_default_dnums_for_backward_input_.add_spatial_dimensions(2);
+ tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3);
+ tf_default_dnums_for_backward_input_.set_kernel_output_feature_dimension(2);
+ tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(0);
+ tf_default_dnums_for_backward_input_.add_kernel_spatial_dimensions(1);
+ }
+
+ protected:
+ bool FoldConvolution(HloModule* module) {
+ ConvolutionFolding convolution_folding;
+ return convolution_folding.Run(module).ValueOrDie();
+ }
+
+ // A convolution window with stride 1 and zero padding. The size fields are
+ // not set.
+ Window default_conv_window_;
+ ConvolutionDimensionNumbers tf_default_dnums_for_backward_filter_;
+ ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
+};
+
+TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 2, 1}), "gradients"));
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_size(2);
+ conv_window.mutable_dimensions(1)->set_window_dilation(2);
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(activations->shape(),
+ gradients->shape(), conv_window,
+ tf_default_dnums_for_backward_filter_)
+ .ConsumeValueOrDie(),
+ activations, gradients, conv_window,
+ tf_default_dnums_for_backward_filter_));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
+ entry_computation->root_instruction()->fusion_kind());
+}
+
+TEST_F(ConvolutionFoldingTest,
+ BackwardFilterConvolveEquivalentToForwardConvolution) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "gradients"));
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_size(3);
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(activations->shape(),
+ gradients->shape(), conv_window,
+ tf_default_dnums_for_backward_filter_)
+ .ConsumeValueOrDie(),
+ activations, gradients, conv_window,
+ tf_default_dnums_for_backward_filter_));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(&module));
+}
+
+// Extracted from block35 training.
+TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(35);
+ conv_window.mutable_dimensions(i)->set_padding_low(1);
+ conv_window.mutable_dimensions(i)->set_padding_high(1);
+ }
+ HloInstruction* convolution =
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
+
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0}));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
+ entry_computation->root_instruction()->fusion_kind());
+}
+
+// Extracted from inception v3 training.
+TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(4);
+ conv_window.mutable_dimensions(i)->set_padding_high(-1);
+ conv_window.mutable_dimensions(i)->set_window_dilation(2);
+ }
+ HloInstruction* convolution =
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
+
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0}));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
+ entry_computation->root_instruction()->fusion_kind());
+}
+
+TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* activations =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "activations"));
+ HloInstruction* gradients =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {20, 35, 35, 32}), "gradients"));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(35);
+ // Uneven padding: padding_low=0, padding_high=1
+ conv_window.mutable_dimensions(i)->set_padding_high(1);
+ }
+ HloInstruction* convolution =
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
+
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0}));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter ==
+ entry_computation->root_instruction()->fusion_kind());
+}
+
+TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {4, 5, 16, 16}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {5, 3, 7, 7}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {2, 3}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(7);
+ conv_window.mutable_dimensions(i)->set_padding_low(3);
+ conv_window.mutable_dimensions(i)->set_padding_high(3);
+ }
+ ConvolutionDimensionNumbers conv_dnums;
+ conv_dnums.set_batch_dimension(0);
+ conv_dnums.set_feature_dimension(1);
+ conv_dnums.add_spatial_dimensions(2);
+ conv_dnums.add_spatial_dimensions(3);
+ conv_dnums.set_kernel_input_feature_dimension(0);
+ conv_dnums.set_kernel_output_feature_dimension(1);
+ conv_dnums.add_kernel_spatial_dimensions(2);
+ conv_dnums.add_kernel_spatial_dimensions(3);
+
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {4, 3, 16, 16}), /*lhs=*/output,
+ /*rhs=*/reverse_kernel, conv_window, conv_dnums));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(),
+ ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), conv_window, conv_dnums)
+ .ValueOrDie()));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
+ entry_computation->root_instruction()->fusion_kind());
+ for (int i = 0; i < 2; ++i) {
+ const WindowDimension& window_dim =
+ entry_computation->root_instruction()->window().dimensions(i);
+ // Low padding of the backward input convolution
+ // = kernel_size - 1 - low padding on gradients.
+ EXPECT_EQ(3, window_dim.padding_low());
+ EXPECT_EQ(3, window_dim.padding_high());
+ EXPECT_EQ(1, window_dim.stride());
+ }
+}
+
+// Convolve([abc], [x], base_dilation=2)
+// = Convolve([abc], Reverse([x]), base_dilation=2)
+// = BackwardInputConvolve([abc], [x], stride=2)
+TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) {
+ auto builder = HloComputation::Builder(TestName());
+ // NHWC dimension order.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // HWOI dimension order.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+ Window conv_window = default_conv_window_;
+ conv_window.mutable_dimensions(1)->set_base_dilation(2);
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
+ conv_window,
+ tf_default_dnums_for_backward_input_)
+ .ConsumeValueOrDie(),
+ /*lhs=*/output, /*rhs=*/kernel, conv_window,
+ tf_default_dnums_for_backward_input_));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
+ entry_computation->root_instruction()->fusion_kind());
+}
+
+// BackwardInputConvolve([abc], [x], stride=1) is equivalent to
+// ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input
+// convolution.
+TEST_F(ConvolutionFoldingTest,
+ BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) {
+ auto builder = HloComputation::Builder(TestName());
+ // NHWC dimension order.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // HWOI dimension order.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 1}), "kernel"));
+
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(output->shape(), kernel->shape(),
+ default_conv_window_,
+ tf_default_dnums_for_backward_input_)
+ .ConsumeValueOrDie(),
+ /*lhs=*/output, /*rhs=*/kernel, default_conv_window_,
+ tf_default_dnums_for_backward_input_));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(&module));
+}
+
+// Extracted from Inception V3 training.
+//
+// filter(HWIO)
+// 3x3x192x320
+// |
+// v
+// gradients(NHWC) reverse
+// 20x4x4x320 3x3x192x320
+// \ /
+// \ /
+// conv (NHWC) with padding (low=2,high=3,interior=1)
+// 20x10x10x192
+//
+// Gradients are padded unevenly.
+TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(3);
+ conv_window.mutable_dimensions(i)->set_padding_low(2);
+ conv_window.mutable_dimensions(i)->set_padding_high(3);
+ // Interior padding = 1.
+ conv_window.mutable_dimensions(i)->set_base_dilation(2);
+ }
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+ conv_window, tf_default_dnums_for_backward_input_));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), conv_window,
+ tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
+
+ HloModule module(TestName());
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ EXPECT_EQ(HloOpcode::kFusion,
+ entry_computation->root_instruction()->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
+ entry_computation->root_instruction()->fusion_kind());
+ for (int i = 0; i < 2; ++i) {
+ const WindowDimension& window_dim =
+ entry_computation->root_instruction()->window().dimensions(i);
+ EXPECT_EQ(0, window_dim.padding_low());
+ EXPECT_EQ(0, window_dim.padding_high());
+ EXPECT_EQ(2, window_dim.stride());
+ }
+}
+
+// Similar to BackwardInputConvolveUnevenPaddingOnGradients, but the low
+// padding of the gradients exceeds kernel_size - 1. Therefore, this pattern
+// cannot be fused.
+TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {20, 4, 4, 320}), "output"));
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ for (int i = 0; i < 2; ++i) {
+ conv_window.mutable_dimensions(i)->set_size(3);
+ conv_window.mutable_dimensions(i)->set_padding_low(3);
+ conv_window.mutable_dimensions(i)->set_padding_high(2);
+ conv_window.mutable_dimensions(i)->set_base_dilation(2);
+ }
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {20, 10, 10, 192}), output, reverse_kernel,
+ conv_window, tf_default_dnums_for_backward_input_));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), conv_window,
+ tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(&module));
+}
+
+// Extracted from //learning/brain/google/xla/benchmarks/resnet.py
+//
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+// [4] = conv([14], [3], stride=2, padding_high=1) // Padding::kSame
+//
+// BC = BackwardInput(FC) does:
+// [14] = conv([7], reverse([3]),
+// padding_low=2, padding_high=1, base_dilation=2)
+//
+// We should fuse BC even though padding on activations is uneven, because
+// PadInsertion will canonicalize the fusion HLO.
+TEST_F(ConvolutionFoldingTest,
+ BackwardInputConvolveUnevenPaddingOnActivations) {
+ auto builder = HloComputation::Builder(TestName());
+  // The gradients are in NHWC layout.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 7, 1}), "output"));
+  // The kernel is in HWOI layout.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 3, 1, 1}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+ forward_conv_col_dim->set_size(3);
+ forward_conv_col_dim->set_padding_low(2);
+ forward_conv_col_dim->set_padding_high(1);
+ forward_conv_col_dim->set_base_dilation(2);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {1, 1, 14, 1}), output, reverse_kernel,
+ conv_window, tf_default_dnums_for_backward_input_));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), conv_window,
+ tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
+
+ HloModule module(TestName());
+ const HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build());
+ EXPECT_TRUE(FoldConvolution(&module));
+ const HloInstruction* backward_conv = entry_computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode());
+ EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput ==
+ backward_conv->fusion_kind());
+ const WindowDimension& backward_conv_col_dim =
+ backward_conv->window().dimensions(1);
+ EXPECT_EQ(0, backward_conv_col_dim.padding_low());
+ EXPECT_EQ(1, backward_conv_col_dim.padding_high());
+}
+
+// For simplicity, we focus on the column dimension and ignore other dimensions.
+// We use [?] to represent the shape instead of the content.
+//
+// Suppose operator FC does
+// [3] = conv([4], [2], padding_low=1, padding_high=-1)
+//
+// BC = BackwardInput(FC) does:
+// [4] = conv([3], reverse([2]), padding_high=2)
+//
+// We currently don't fuse BC because PadInsertion doesn't support negative
+// padding on the gradients of backward convolution (b/32744257).
+TEST_F(ConvolutionFoldingTest,
+ BackwardInputConvolveNegativePaddingHighOnActivations) {
+ auto builder = HloComputation::Builder(TestName());
+ // The gradients are in NCHW layout.
+ HloInstruction* output =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 3, 1}), "output"));
+ // The kernel is in HWIO layout.
+ HloInstruction* kernel =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 2, 1, 1}), "kernel"));
+ HloInstruction* reverse_kernel = builder.AddInstruction(
+ HloInstruction::CreateReverse(kernel->shape(), kernel, {0, 1}));
+
+ Window conv_window = default_conv_window_;
+ WindowDimension* forward_conv_col_dim = conv_window.mutable_dimensions(1);
+ forward_conv_col_dim->set_size(2);
+ forward_conv_col_dim->set_padding_high(2);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {1, 1, 4, 1}), output, reverse_kernel,
+ conv_window, tf_default_dnums_for_backward_input_));
+ // Verify the convolution's shape is consistent with ShapeInference.
+ CHECK(ShapeUtil::Compatible(
+ conv->shape(), ShapeInference::InferConvolveShape(
+ output->shape(), reverse_kernel->shape(), conv_window,
+ tf_default_dnums_for_backward_input_)
+ .ValueOrDie()));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+ EXPECT_FALSE(FoldConvolution(&module));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
new file mode 100644
index 0000000000..30a92ab313
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc
@@ -0,0 +1,324 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/legacy_flags/convolution_thunk_flags.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+using Index = BufferAllocation::Index;
+using se::dnn::BatchDescriptor;
+using se::dnn::ConvolutionDescriptor;
+using se::dnn::DataLayout;
+using se::dnn::FilterDescriptor;
+using se::dnn::FilterLayout;
+
+ConvolveScratchAllocator::ConvolveScratchAllocator(
+ int device_ordinal, DeviceMemoryAllocator* memory_allocator)
+ : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+ConvolveScratchAllocator::~ConvolveScratchAllocator() {
+ for (auto& allocated_buffer : allocated_buffers_) {
+ if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer)
+ .ok()) {
+      // The program can continue even if deallocation fails.
+ LOG(ERROR) << "Failed to deallocate the allocated buffer: "
+ << allocated_buffer.opaque();
+ }
+ }
+}
+
+int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) {
+ constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default.
+ return kConvolveScratchSize;
+}
+
+se::port::StatusOr<se::DeviceMemory<uint8>>
+ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be non-negative.";
+ if (byte_size > GetMemoryLimitInBytes(stream)) {
+ return se::port::Status(
+ se::port::error::RESOURCE_EXHAUSTED,
+ tensorflow::strings::Printf(
+ "Allocating %lld bytes exceeds the memory limit of %lld bytes.",
+ byte_size, GetMemoryLimitInBytes(stream)));
+ }
+
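+  // Allocate without retrying on failure, and report a failed allocation as
+  // RESOURCE_EXHAUSTED.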
+ auto status_or_memory =
+ memory_allocator_->Allocate(device_ordinal_, byte_size,
+ /*retry_on_failure=*/false);
+ if (!status_or_memory.ok()) {
+ return se::port::Status(se::port::error::RESOURCE_EXHAUSTED,
+ tensorflow::strings::Printf(
+ "Failed to allocate %lld bytes on device %d.",
+ byte_size, device_ordinal_));
+ }
+ se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie();
+ allocated_buffers_.push_back(allocated_buffer);
+ total_allocated_bytes_ += byte_size;
+ return se::DeviceMemory<uint8>(allocated_buffer);
+}
+
+string ConvolutionKindToString(
+ ConvolutionThunk::ConvolutionKind convolution_kind) {
+ switch (convolution_kind) {
+ case ConvolutionThunk::ConvolutionKind::kForward:
+ return "forward";
+ case ConvolutionThunk::ConvolutionKind::kBackwardFilter:
+ return "backward_filter";
+ case ConvolutionThunk::ConvolutionKind::kBackwardInput:
+ return "backward_input";
+ }
+}
+
+ConvolutionThunk::ConvolutionThunk(
+ ConvolutionKind convolution_kind, Index input_buffer, Index filter_buffer,
+ Index output_buffer, const Shape& input_shape, const Shape& filter_shape,
+ const Shape& output_shape, const Window& window,
+ const ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo)
+ : Thunk(Kind::kConvolution, hlo),
+ convolution_kind_(convolution_kind),
+ input_buffer_(input_buffer),
+ filter_buffer_(filter_buffer),
+ output_buffer_(output_buffer),
+ input_shape_(input_shape),
+ filter_shape_(filter_shape),
+ output_shape_(output_shape),
+ window_(window),
+ dim_nums_(dim_nums) {}
+
+tensorflow::Status ConvolutionThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_);
+ VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }";
+ VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }";
+ VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }";
+ VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }";
+ VLOG(3) << "Window: { " << window_.ShortDebugString() << " }";
+
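+  // Only F32 convolutions over two spatial dimensions with symmetric zero
+  // padding are handled here.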
+ CHECK_EQ(F32, output_shape_.element_type());
+ CHECK_EQ(2, window_.dimensions_size());
+ for (const WindowDimension& dim : window_.dimensions()) {
+ CHECK_EQ(dim.padding_low(), dim.padding_high());
+ }
+
+ const WindowDimension& height = window_.dimensions(0);
+ const WindowDimension& width = window_.dimensions(1);
+ // cuDNN's convolution APIs support the BDYX layout for activations/output and
+ // the OIYX layout for weights.
+ // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls
+ // when we switch to cuDNN v5.
+ BatchDescriptor input_descriptor;
+ input_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ .set_height(input_shape_.dimensions(dim_nums_.spatial_dimensions(0)))
+ .set_width(input_shape_.dimensions(dim_nums_.spatial_dimensions(1)))
+ .set_feature_map_count(
+ input_shape_.dimensions(dim_nums_.feature_dimension()))
+ .set_count(input_shape_.dimensions(dim_nums_.batch_dimension()));
+
+ FilterDescriptor filter_descriptor;
+ filter_descriptor.set_layout(FilterLayout::kOutputInputYX)
+ .set_input_feature_map_count(
+ filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension()))
+ .set_output_feature_map_count(
+ filter_shape_.dimensions(dim_nums_.kernel_output_feature_dimension()))
+ .set_input_filter_height(
+ filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(0)))
+ .set_input_filter_width(
+ filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(1)));
+
+ ConvolutionDescriptor convolution_descriptor;
+ convolution_descriptor.set_zero_padding_width(width.padding_low())
+ .set_zero_padding_height(height.padding_low())
+ .set_horizontal_filter_stride(width.stride())
+ .set_vertical_filter_stride(height.stride());
+
+ BatchDescriptor output_descriptor;
+ output_descriptor.set_layout(DataLayout::kBatchDepthYX)
+ .set_height(output_shape_.dimensions(dim_nums_.spatial_dimensions(0)))
+ .set_width(output_shape_.dimensions(dim_nums_.spatial_dimensions(1)))
+ .set_feature_map_count(
+ output_shape_.dimensions(dim_nums_.feature_dimension()))
+ .set_count(output_shape_.dimensions(dim_nums_.batch_dimension()));
+
+ se::DeviceMemory<float> input_data(
+ buffer_allocations.GetDeviceAddress(input_buffer_));
+ se::DeviceMemory<float> filter_data(
+ buffer_allocations.GetDeviceAddress(filter_buffer_));
+ se::DeviceMemory<float> output_data(
+ buffer_allocations.GetDeviceAddress(output_buffer_));
+ return ConvolveWithTune(input_descriptor, input_data, filter_descriptor,
+ filter_data, output_descriptor, output_data,
+ convolution_descriptor, buffer_allocations, stream);
+}
+
+tensorflow::Status ConvolutionThunk::Convolve(
+ const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
+ const FilterDescriptor& filter_descriptor,
+ se::DeviceMemory<float> filter_data,
+ const BatchDescriptor& output_descriptor,
+ se::DeviceMemory<float> output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream,
+ ConvolveScratchAllocator* scratch_allocator,
+ se::dnn::ProfileResult* profile_result) {
+ bool launch_ok;
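+  // Each kind writes its result into a different operand: the filter for
+  // kBackwardFilter, the input for kBackwardInput, and the output for
+  // kForward.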
+ switch (convolution_kind_) {
+ case ConvolutionKind::kBackwardFilter:
+ launch_ok =
+ stream
+ ->ThenConvolveBackwardFilterWithAlgorithm(
+ input_descriptor, input_data, output_descriptor, output_data,
+ convolution_descriptor, filter_descriptor, &filter_data,
+ scratch_allocator, algorithm_config, profile_result)
+ .ok();
+ break;
+ case ConvolutionKind::kBackwardInput:
+ launch_ok = stream
+ ->ThenConvolveBackwardDataWithAlgorithm(
+ filter_descriptor, filter_data, output_descriptor,
+ output_data, convolution_descriptor, input_descriptor,
+ &input_data, scratch_allocator, algorithm_config,
+ profile_result)
+ .ok();
+ break;
+ case ConvolutionKind::kForward:
+ launch_ok =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, &output_data,
+ scratch_allocator, algorithm_config, profile_result)
+ .ok();
+ break;
+ }
+ if (launch_ok) {
+ return tensorflow::Status::OK();
+ }
+ return InternalError(
+ "Unable to launch convolution for thunk %p with type %s and algorithm "
+ "(%lld, %lld)",
+ this, ConvolutionKindToString(convolution_kind_).c_str(),
+ algorithm_config.algorithm(), algorithm_config.algorithm_no_scratch());
+}
+
+std::vector<se::dnn::AlgorithmType> ConvolutionThunk::GetAlgorithms(
+ se::StreamExecutor* stream_exec) const {
+ std::vector<se::dnn::AlgorithmType> algorithms;
+ switch (convolution_kind_) {
+ case ConvolutionKind::kBackwardFilter:
+ CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms(&algorithms));
+ break;
+ case ConvolutionKind::kBackwardInput:
+ CHECK(stream_exec->GetConvolveBackwardDataAlgorithms(&algorithms));
+ break;
+ case ConvolutionKind::kForward:
+ CHECK(stream_exec->GetConvolveAlgorithms(&algorithms));
+ break;
+ }
+ return algorithms;
+}
+
+tensorflow::Status ConvolutionThunk::ConvolveWithTune(
+ const BatchDescriptor& input_descriptor, se::DeviceMemory<float> input_data,
+ const FilterDescriptor& filter_descriptor,
+ se::DeviceMemory<float> filter_data,
+ const BatchDescriptor& output_descriptor,
+ se::DeviceMemory<float> output_data,
+ const ConvolutionDescriptor& convolution_descriptor,
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out.
+ legacy_flags::ConvolutionThunkFlags* flags =
+ legacy_flags::GetConvolutionThunkFlags();
+ if (flags->xla_gpu_autotune_convolution_algorithm &&
+ best_algorithm_.algorithm() == se::dnn::kDefaultAlgorithm) {
+    // Auto-tuning is either disabled or happens only on the first run of this
+    // function, after which best_algorithm_ is no longer the default.
+ VLOG(2) << "Profiling for best convolution algorithm used for "
+ "ConvolutionThunk: "
+ << this;
+
+ se::dnn::ProfileResult best_result;
+ se::dnn::ProfileResult best_result_without_scratch;
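+    // Profile every candidate algorithm, tracking both the fastest result
+    // overall and the fastest result that allocated no scratch memory.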
+ for (se::dnn::AlgorithmType algorithm : GetAlgorithms(stream->parent())) {
+ ConvolveScratchAllocator scratch_allocator(
+ buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+ se::dnn::ProfileResult profile_result;
+ bool launch_ok =
+ Convolve(input_descriptor, input_data, filter_descriptor, filter_data,
+ output_descriptor, output_data, convolution_descriptor,
+ se::dnn::AlgorithmConfig(algorithm, algorithm), stream,
+ &scratch_allocator, &profile_result)
+ .ok();
+ if (launch_ok && profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalAllocatedBytes() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_without_scratch.elapsed_time_in_ms()) {
+ best_result_without_scratch = profile_result;
+ }
+ }
+ }
+
+ if (best_result.is_valid()) {
+ best_algorithm_.set_algorithm(best_result.algorithm());
+ } else {
+ LOG(ERROR) << "No convolution algorithm works with profiling. Fall back "
+ "to the default algorithm.";
+ best_algorithm_.set_algorithm(se::dnn::kDefaultAlgorithm);
+ }
+
+ if (best_result_without_scratch.is_valid()) {
+ best_algorithm_.set_algorithm_no_scratch(
+ best_result_without_scratch.algorithm());
+ } else {
+ LOG(ERROR) << "No convolution algorithm without scratch works with "
+ "profiling. Fall back "
+ "to the default algorithm.";
+ best_algorithm_.set_algorithm_no_scratch(se::dnn::kDefaultAlgorithm);
+ }
+ }
+
+ {
+ VLOG(2) << "Using convolution algorithm (" << best_algorithm_.algorithm()
+ << ", " << best_algorithm_.algorithm_no_scratch()
+ << ") for ConvolutionThunk: " << this;
+ ConvolveScratchAllocator scratch_allocator(
+ buffer_allocations.device_ordinal(),
+ buffer_allocations.memory_allocator());
+ return Convolve(input_descriptor, input_data, filter_descriptor,
+ filter_data, output_descriptor, output_data,
+ convolution_descriptor, best_algorithm_, stream,
+ &scratch_allocator, nullptr);
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
new file mode 100644
index 0000000000..cd9568f6a2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A one-time scratch allocator for forward and backward convolution. The
+// scratch buffers allocated are released on destruction.
+//
+// Not thread-safe.
+class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+ ConvolveScratchAllocator(int device_ordinal,
+ DeviceMemoryAllocator* memory_allocator);
+
+ ~ConvolveScratchAllocator() override;
+
+ int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override;
+
+ int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+ perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
+ AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override;
+
+ private:
+ const int device_ordinal_;
+ DeviceMemoryAllocator* memory_allocator_;
+ std::vector<perftools::gputools::DeviceMemoryBase> allocated_buffers_;
+ int64 total_allocated_bytes_ = 0;
+};
+
+// This class stores everything that StreamExecutor needs to launch a DNN
+// convolution. It is generated by IrEmitter.
+//
+// This is thread-compatible.
+class ConvolutionThunk : public Thunk {
+ public:
+ // ConvolutionThunk performs one of the following types of convolution.
+ enum class ConvolutionKind {
+ kBackwardFilter, // Backward convolution for filter.
+ kBackwardInput, // Backward convolution for input.
+ kForward, // Forward convolution.
+ };
+
+ // Constructs a thunk for launching a DNN convolution.
+ // Semantics of null hlo_instruction argument are as in Thunk.
+ ConvolutionThunk(ConvolutionKind convolution_kind,
+ BufferAllocation::Index input_buffer,
+ BufferAllocation::Index filter_buffer,
+ BufferAllocation::Index output_buffer,
+ const Shape& input_shape, const Shape& filter_shape,
+ const Shape& output_shape, const Window& window,
+ const ConvolutionDimensionNumbers& dnums,
+ const HloInstruction* hlo);
+
+ ConvolutionThunk(const ConvolutionThunk&) = delete;
+ ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
+
+  // Does the convolution for the thunk on "stream". If the
+  // xla_gpu_autotune_convolution_algorithm flag is turned on, auto-tuning
+  // happens on the first run of this function.
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ tensorflow::Status ConvolveWithTune(
+ const perftools::gputools::dnn::BatchDescriptor& input_descriptor,
+ perftools::gputools::DeviceMemory<float> input_data,
+ const perftools::gputools::dnn::FilterDescriptor& filter_descriptor,
+ perftools::gputools::DeviceMemory<float> filter_data,
+ const perftools::gputools::dnn::BatchDescriptor& output_descriptor,
+ perftools::gputools::DeviceMemory<float> output_data,
+ const perftools::gputools::dnn::ConvolutionDescriptor&
+ convolution_descriptor,
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream);
+
+ tensorflow::Status Convolve(
+ const perftools::gputools::dnn::BatchDescriptor& input_descriptor,
+ perftools::gputools::DeviceMemory<float> input_data,
+ const perftools::gputools::dnn::FilterDescriptor& filter_descriptor,
+ perftools::gputools::DeviceMemory<float> filter_data,
+ const perftools::gputools::dnn::BatchDescriptor& output_descriptor,
+ perftools::gputools::DeviceMemory<float> output_data,
+ const perftools::gputools::dnn::ConvolutionDescriptor&
+ convolution_descriptor,
+ const perftools::gputools::dnn::AlgorithmConfig& algorithm_config,
+ perftools::gputools::Stream* stream,
+ ConvolveScratchAllocator* scratch_allocator,
+ perftools::gputools::dnn::ProfileResult* profile_result);
+
+ // Returns the convolve algorithms that can be used for this ConvolutionThunk.
+ std::vector<perftools::gputools::dnn::AlgorithmType> GetAlgorithms(
+ perftools::gputools::StreamExecutor* stream_exec) const;
+
+  // Fastest cuDNN convolution algorithm for this thunk, learned from
+  // auto-tuning. If auto-tuning is disabled or fails, best_algorithm_ is set
+  // to the default value, indicating that cuDNN should choose the algorithm
+  // from its own heuristics based on the convolution's parameters.
+ perftools::gputools::dnn::AlgorithmConfig best_algorithm_;
+
+ ConvolutionKind convolution_kind_;
+
+ BufferAllocation::Index input_buffer_;
+ BufferAllocation::Index filter_buffer_;
+ BufferAllocation::Index output_buffer_;
+
+ Shape input_shape_;
+ Shape filter_shape_;
+ Shape output_shape_;
+
+ Window window_;
+
+ ConvolutionDimensionNumbers dim_nums_;
+};
+
+string ConvolutionKindToString(
+ ConvolutionThunk::ConvolutionKind convolution_kind);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc
new file mode 100644
index 0000000000..926690c5c0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.cc
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"
+
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(bool changed, CopyInsertion::Run(module));
+
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+
+  // Make sure all operands of a library call live in memory rather than being
+  // constants embedded in the IR. The top-level (index {}) of the points-to
+  // set of each operand indicates the source(s) of the array buffer. If any of
+  // these sources is a constant, add a copy to materialize the array.
+ HloComputation* computation = module->entry_computation();
+ for (HloInstruction* hlo : computation->MakeInstructionPostOrder()) {
+ if (ImplementedAsLibraryCall(*hlo)) {
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ HloInstruction* operand = hlo->mutable_operand(i);
+ const PointsToSet& points_to =
+ points_to_analysis->GetPointsToSet(operand);
+ const auto& element = points_to.element(/*index=*/{});
+ if (std::any_of(element.begin(), element.end(),
+ [](const LogicalBuffer* buffer_source) {
+ return buffer_source->instruction()->opcode() ==
+ HloOpcode::kConstant;
+ })) {
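+          // Materialize the constant-sourced operand with a copy (reusing an
+          // existing one if available) and rewire the library call to use it.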
+ TF_ASSIGN_OR_RETURN(HloInstruction * copy,
+ CopyInsertion::FindOrInsertCopy(operand));
+ hlo->ReplaceOperandWith(i, copy);
+ changed = true;
+ }
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/copy_insertion.h b/tensorflow/compiler/xla/service/gpu/copy_insertion.h
new file mode 100644
index 0000000000..11077dad2e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/copy_insertion.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
+
+#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+namespace gpu {
+
+// Besides the modifications made by the generic xla::CopyInsertion, this
+// GPU-specific copy insertion also materializes operands of library calls by
+// inserting kCopy instructions.
+class GpuCopyInsertion : public CopyInsertion {
+ public:
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
new file mode 100644
index 0000000000..76fb079bd4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
+
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+CopyThunk::CopyThunk(const void* source_address,
+ BufferAllocation::Index destination_buffer,
+ uint64 mem_size, const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kCopy, hlo_instruction),
+ source_address_(source_address),
+ destination_buffer_(destination_buffer),
+ mem_size_(mem_size) {}
+
+tensorflow::Status CopyThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ perftools::gputools::DeviceMemoryBase destination_data =
+ buffer_allocations.GetDeviceAddress(destination_buffer_);
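+  // Enqueue a host-to-device copy of mem_size_ bytes onto the stream.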
+ stream->ThenMemcpy(&destination_data, source_address_, mem_size_);
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
new file mode 100644
index 0000000000..803e699bfd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+// A thunk that copies data. For now, it copies data only from host to device.
+// But it can be extended to perform device-to-host or intra-device copying.
+class CopyThunk : public Thunk {
+ public:
+ // Constructs a CopyThunk that copies host data from `source_address` to the
+ // device buffer `destination_buffer`. `mem_size` is the size of the data in
+ // bytes.
+ CopyThunk(const void* source_address,
+ BufferAllocation::Index destination_buffer, uint64 mem_size,
+ const HloInstruction* hlo_instruction);
+
+ CopyThunk(const CopyThunk&) = delete;
+ CopyThunk& operator=(const CopyThunk&) = delete;
+
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ const void* source_address_;
+ BufferAllocation::Index destination_buffer_;
+ uint64 mem_size_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_COPY_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
new file mode 100644
index 0000000000..e318ade5ee
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -0,0 +1,396 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+
+#include <stddef.h>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+// IWYU pragma: no_include "llvm/IR/Attributes.gen.inc"
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "external/llvm/include/llvm/ADT/APInt.h"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Intrinsics.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Type.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace gpu {
+
+using llvm_ir::IrArray;
+using llvm_ir::SetToFirstInsertPoint;
+
+GpuElementalIrEmitter::GpuElementalIrEmitter(
+ const HloModuleConfig& hlo_module_config, llvm::Module* module,
+ llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested)
+ : ElementalIrEmitter(hlo_module_config, module, ir_builder),
+ compute_nested_(std::move(compute_nested)) {}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitMathCall(
+ const string& callee_name,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
+ PrimitiveType output_type) const {
+  // Binary math functions are of type [T] -> T.
+ for (PrimitiveType input_type : input_types) {
+ if (output_type != input_type) {
+ return Unimplemented("Input type ≠ output type: %s ≠ %s",
+ PrimitiveType_Name(input_type).c_str(),
+ PrimitiveType_Name(output_type).c_str());
+ }
+ }
+
+  // The libdevice math functions differentiate between the "double" and
+  // "float" variants by appending an 'f' to the name of the float version.
+ string function_name = callee_name;
+ switch (output_type) {
+ case F32:
+ function_name += 'f';
+ break;
+ case F64:
+ break;
+ default:
+ return Unimplemented("Bad type for math call: %s",
+ PrimitiveType_Name(output_type).c_str());
+ }
+
+ return EmitDeviceFunctionCall(
+ function_name, operands, input_types, output_type,
+ {llvm::Attribute::ReadNone, llvm::Attribute::NoUnwind});
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const {
+ PrimitiveType lhs_input_type = op->operand(0)->shape().element_type();
+ PrimitiveType rhs_input_type = op->operand(1)->shape().element_type();
+ PrimitiveType output_type = op->shape().element_type();
+ switch (op->opcode()) {
+ case HloOpcode::kRemainder: {
+ return EmitMathCall("__nv_fmod", {lhs_value, rhs_value},
+ {lhs_input_type, rhs_input_type}, output_type);
+ }
+ case HloOpcode::kPower: {
+ return EmitMathCall("__nv_pow", {lhs_value, rhs_value},
+ {lhs_input_type, rhs_input_type}, output_type);
+ }
+ default:
+ return ElementalIrEmitter::EmitFloatBinaryOp(op, lhs_value, rhs_value);
+ }
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitErfcInv(
+ PrimitiveType prim_type, llvm::Value* value) const {
+ return EmitMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type);
+}
+
+StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const {
+ PrimitiveType input_type = op->operand(0)->shape().element_type();
+ PrimitiveType output_type = op->shape().element_type();
+ switch (op->opcode()) {
+ case HloOpcode::kExp:
+ return EmitMathCall("__nv_exp", {operand_value}, {input_type},
+ output_type);
+ case HloOpcode::kFloor:
+ return EmitMathCall("__nv_floor", {operand_value}, {input_type},
+ output_type);
+ case HloOpcode::kCeil:
+ return EmitMathCall("__nv_ceil", {operand_value}, {input_type},
+ output_type);
+ case HloOpcode::kLog:
+ return EmitMathCall("__nv_log", {operand_value}, {input_type},
+ output_type);
+ case HloOpcode::kTanh:
+ return EmitMathCall("__nv_tanh", {operand_value}, {input_type},
+ output_type);
+ default:
+ return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value);
+ }
+}
+
+llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
+ const string& callee_name,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
+ PrimitiveType output_type,
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const {
+ std::vector<llvm::Type*> ir_input_types;
+ for (PrimitiveType input_type : input_types) {
+ ir_input_types.push_back(
+ llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_));
+ }
+ llvm::FunctionType* callee_type = llvm::FunctionType::get(
+ llvm_ir::PrimitiveTypeToIrType(output_type,
+ ir_builder_), // The return type.
+ ir_input_types, // The parameter types.
+ false); // No variadic arguments.
+
+ // Declares the callee if it is not declared already.
+ llvm::Function* callee = llvm::cast<llvm::Function>(
+ ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction(
+ llvm_ir::AsStringRef(callee_name), callee_type));
+
+ for (auto attribute : attributes) {
+ callee->addFnAttr(attribute);
+ }
+
+ return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+}
+
+llvm::Value* GpuElementalIrEmitter::EmitThreadId() const {
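+  // Compute the linear thread id as block_id * threads_per_block +
+  // thread_id_in_block, read from the NVVM PTX special registers ctaid.x,
+  // ntid.x, and tid.x.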
+ llvm::Value* block_id = ir_builder_->CreateIntCast(
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
+ {}, {}, ir_builder_),
+ ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast(
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
+ {}, {}, ir_builder_),
+ ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ llvm::Value* threads_per_block = ir_builder_->CreateIntCast(
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
+ {}, {}, ir_builder_),
+ ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ return ir_builder_->CreateNSWAdd(
+ ir_builder_->CreateNSWMul(block_id, threads_per_block),
+ thread_id_in_block);
+}
+
+llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const {
+ switch (hlo->opcode()) {
+ case HloOpcode::kPad:
+ return [=, &operand_to_generator](
+ const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
+ auto index = padded_index;
+ llvm::Value* in_bounds =
+ llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
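+        // Map the padded output index back to an operand index: subtract the
+        // low edge padding, require the index to land on an interior-padding
+        // stride, divide by that stride, and bounds-check against the
+        // operand's dimensions.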
+ for (int i = 0; i < index.size(); ++i) {
+ auto index_typed_const = [=](int64 n) {
+ return llvm::ConstantInt::get(index[i]->getType(), n);
+ };
+ const auto& pad_dim = hlo->padding_config().dimensions(i);
+ index[i] = ir_builder_->CreateSub(
+ index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
+ "in_bounds");
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpEQ(
+ index_typed_const(0),
+ ir_builder_->CreateURem(
+ index[i],
+ index_typed_const(pad_dim.interior_padding() + 1))),
+ "in_bounds");
+ index[i] = ir_builder_->CreateSDiv(
+ index[i], index_typed_const(pad_dim.interior_padding() + 1));
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpSLT(
+ index[i],
+ index_typed_const(hlo->operand(0)->shape().dimensions(i))),
+ "in_bounds");
+ }
+
+ // if (in_bounds) {
+ // ret_value = operand0[index]; // source
+ // } else {
+ // ret_value = *operand1; // padding
+ // }
+ llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
+ ir_builder_),
+ "pad_result_addr", ir_builder_);
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
+ operand_to_generator.at(hlo->operand(0))(index));
+ ir_builder_->CreateStore(operand_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
+ operand_to_generator.at(hlo->operand(1))({}));
+ ir_builder_->CreateStore(padding_value, ret_value_addr);
+
+ SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ // Don't create phi(operand_value, padding_value) here, because invoking
+ // operand_to_generator may create new basic blocks, making the parent
+ // of operand_value or padding_value no longer a predecessor of
+ // if_data.after_block.
+ return ir_builder_->CreateLoad(ret_value_addr);
+ };
+ case HloOpcode::kMap:
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ TF_RET_CHECK(!hlo->operands().empty())
+ << "Zero operand map not implemented in GPU backend.";
+ TF_RET_CHECK(hlo->to_apply()->num_parameters() > 0);
+ std::vector<llvm::Value*> operand_elements;
+ for (HloInstruction* operand : hlo->operands()) {
+ TF_ASSIGN_OR_RETURN(llvm::Value * value,
+ operand_to_generator.at(operand)(index));
+ operand_elements.push_back(value);
+ }
+ return compute_nested_(*hlo->to_apply(), operand_elements);
+ };
+ case HloOpcode::kReduceWindow:
+ // Pseudocode:
+      // for each index O in output
+      //   value = init_value
+      //   for each index W in window
+      //     for each dimension i from 0 to rank - 1
+      //       (input index I)[i] = O[i] * stride[i] + W[i] - pad_low[i]
+      //     if I is in bounds of input
+      //       value = function(value, input[I])
+      //   output[O] = value
+ return [=, &operand_to_generator](
+ const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* operand = hlo->operand(0);
+ const Window& window = hlo->window();
+
+ // TODO(b/31410564): Implement dilation for reduce-window.
+ if (window_util::HasDilation(window)) {
+ return Unimplemented(
+ "Dilation for reduce-window not implemented on GPU. "
+ "See b/31410564.");
+ }
+
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_),
+ "reduce_window_accum_ptr", ir_builder_);
+ {
+ TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
+ operand_to_generator.at(hlo->operand(1))({}));
+ ir_builder_->CreateStore(init_value, accum_ptr);
+ }
+
+ llvm_ir::ForLoopNest loops(ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ }
+ const IrArray::Index window_index = loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ CHECK_EQ(window_index.size(), index.size());
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
+
+ IrArray::Index input_index(index.size());
+ llvm::Value* in_bounds = ir_builder_->getInt1(1);
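+        // Compute the input index for this window position; positions that
+        // fall in the padding are rejected by the bounds check below.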
+ for (size_t i = 0; i < index.size(); ++i) {
+ llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
+ index[i], ir_builder_->getInt64(window.dimensions(i).stride()));
+ input_index[i] = ir_builder_->CreateNSWSub(
+ ir_builder_->CreateNSWAdd(stridden_index, window_index[i]),
+ ir_builder_->getInt64(window.dimensions(i).padding_low()));
+
+ // We must check whether 0 ≤ input_index[i] < bound, as otherwise
+ // we are in the pad and so can skip the computation. This
+ // comparison is equivalent to the unsigned comparison
+ // input_index[i] < bound, as a negative value wraps to a large
+ // positive value.
+ in_bounds = ir_builder_->CreateAnd(
+ in_bounds,
+ ir_builder_->CreateICmpULT(
+ input_index[i],
+ ir_builder_->getInt64(operand->shape().dimensions(i))));
+ }
+
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+
+ // We are not in pad, so do the computation.
+ TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
+ operand_to_generator.at(operand)(input_index));
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * accum_value,
+ compute_nested_(*hlo->to_apply(),
+ {ir_builder_->CreateLoad(accum_ptr), input_value}));
+ ir_builder_->CreateStore(accum_value, accum_ptr);
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_);
+ return ir_builder_->CreateLoad(accum_ptr);
+ };
+ case HloOpcode::kReduce:
+ return [=, &operand_to_generator](
+ const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
+ const HloInstruction* operand = hlo->operand(0);
+ llvm::Value* accum_ptr =
+ ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ hlo->shape().element_type(), ir_builder()));
+ TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
+ operand_to_generator.at(hlo->operand(1))({}));
+ ir_builder()->CreateStore(init_value, accum_ptr);
+
+ llvm_ir::ForLoopNest loops(ir_builder_);
+ IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions(
+ operand->shape(), hlo->dimensions(), "reduction_dim");
+ if (!ShapeUtil::IsScalar(hlo->shape())) {
+          // Only the entries of input_index along hlo->dimensions() are
+          // non-null; fill in the remaining entries from output_index.
+ size_t j = 0;
+ for (size_t i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = output_index[j++];
+ }
+ }
+ CHECK_EQ(output_index.size(), j);
+ }
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder());
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * input_value,
+ operand_to_generator.at(hlo->operand(0))(input_index));
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * accum_value,
+ compute_nested_(
+ *hlo->to_apply(),
+ {ir_builder()->CreateLoad(accum_ptr), input_value}));
+ ir_builder()->CreateStore(accum_value, accum_ptr);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder());
+ return ir_builder()->CreateLoad(accum_ptr);
+ };
+ default:
+ return ElementalIrEmitter::MakeElementGenerator(hlo,
+ operand_to_generator);
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
new file mode 100644
index 0000000000..8e5bdc59cf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_
+
+#include <functional>
+#include <string>
+#include <utility>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuElementalIrEmitter : public ElementalIrEmitter {
+ public:
+  // A NestedComputer computes an element of the output of the given
+  // computation from an ArraySlice of its input elements.
+ using NestedComputer = std::function<StatusOr<llvm::Value*>(
+ const HloComputation&, tensorflow::gtl::ArraySlice<llvm::Value*>)>;
+
+ GpuElementalIrEmitter(const HloModuleConfig& hlo_module_config,
+ llvm::Module* module, llvm::IRBuilder<>* ir_builder,
+ NestedComputer compute_nested);
+
+ llvm_ir::ElementGenerator MakeElementGenerator(
+ const HloInstruction* hlo,
+ const HloToElementGeneratorMap& operand_to_generator) const override;
+
+ protected:
+ StatusOr<llvm::Value*> EmitFloatUnaryOp(
+ const HloInstruction* op, llvm::Value* operand_value) const override;
+
+ StatusOr<llvm::Value*> EmitFloatBinaryOp(
+ const HloInstruction* op, llvm::Value* lhs_value,
+ llvm::Value* rhs_value) const override;
+
+ StatusOr<llvm::Value*> EmitErfcInv(PrimitiveType prim_type,
+ llvm::Value* value) const override;
+
+ llvm::Value* EmitThreadId() const override;
+
+ private:
+  // Emit IR to call a device function named "callee_name" on the given
+  // operands. Returns the IR value that represents the return value.
+ llvm::Value* EmitDeviceFunctionCall(
+ const string& callee_name,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+      tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
+ PrimitiveType output_type,
+ tensorflow::gtl::ArraySlice<llvm::Attribute::AttrKind> attributes) const;
+
+ // Emit IR to call a device function of type [T] -> T. It adjusts the
+ // callee_name to account for float/double types.
+ // Returns the IR value that represents the return value.
+ StatusOr<llvm::Value*> EmitMathCall(
+ const string& callee_name,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<PrimitiveType> input_types,
+ PrimitiveType output_type) const;
+
+ NestedComputer compute_nested_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_ELEMENTAL_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.cc b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
new file mode 100644
index 0000000000..283d21ca22
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.cc
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+ForThunk::ForThunk(const int64 loop_limit,
+ std::unique_ptr<ThunkSequence> body_thunk_sequence,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kWhile, hlo),
+ loop_limit_(loop_limit),
+ body_thunk_sequence_(
+ MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+
+tensorflow::Status ForThunk::Initialize(const GpuExecutable& executable) {
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ForThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ for (int64 i = 0; i < loop_limit_; ++i) {
+ // Invoke loop body thunk sequence.
+ TF_RETURN_IF_ERROR(
+ body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h
new file mode 100644
index 0000000000..525a2af941
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// ForThunk executes 'loop_limit' invocations of 'body_thunk_sequence'.
+class ForThunk : public Thunk {
+ public:
+ ForThunk(const int64 loop_limit,
+ std::unique_ptr<ThunkSequence> body_thunk_sequence,
+ const HloInstruction* hlo);
+ ForThunk(const ForThunk&) = delete;
+ ForThunk& operator=(const ForThunk&) = delete;
+
+ tensorflow::Status Initialize(const GpuExecutable& executable) override;
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ const int64 loop_limit_;
+ std::unique_ptr<SequentialThunk> body_thunk_sequence_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
new file mode 100644
index 0000000000..98a8a4a2b1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -0,0 +1,189 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
+
+#include <functional>
+
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+using Index = BufferAllocation::Index;
+
+namespace {
+
+// This struct contains the metadata of a matrix, e.g., its base address and
+// dimensions.
+struct MatrixDescriptor {
+ MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
+ int64 matrix_num_rows, int64 matrix_num_cols)
+ : data(matrix_data),
+ transpose(needs_transpose),
+ num_rows(matrix_num_rows),
+ num_cols(matrix_num_cols) {}
+
+ se::DeviceMemoryBase data;
+ bool transpose; // Whether this matrix needs to be transposed.
+ int64 num_rows;
+ int64 num_cols;
+};
+
+// Performs a gemm call on lhs_matrix and rhs_matrix and stores the result to
+// output_matrix.
+template <typename Element>
+tensorflow::Status DoGemm(MatrixDescriptor lhs_matrix,
+ MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::Stream* stream) {
+ DCHECK(!output_matrix.transpose);
+
+ se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
+ se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
+ se::DeviceMemory<Element> output_data(output_matrix.data);
+
+ bool launch_ok =
+ stream
+ ->ThenBlasGemm(
+ lhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose,
+ rhs_matrix.transpose ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose,
+ output_matrix.num_rows, output_matrix.num_cols,
+ lhs_matrix.transpose
+ ? lhs_matrix.num_rows
+ : lhs_matrix.num_cols, // Size of the reduce dimension.
+ /*alpha=*/1.0,
+ lhs_data,
+ lhs_matrix.num_rows, // The leading dimension of LHS.
+ rhs_data,
+ rhs_matrix.num_rows, // The leading dimension of RHS.
+ /*beta=*/0.0, &output_data,
+ output_matrix
+ .num_rows) // The leading dimension of the output matrix.
+ .ok();
+ if (!launch_ok) {
+ return InternalError("Unable to launch cuBLAS gemm on stream %p", stream);
+ }
+ return tensorflow::Status::OK();
+}
+
+// Returns the executor for the given type if it is a valid GEMM element type,
+// else null.
+// TODO(b/27202055): consider more element types.
+std::function<tensorflow::Status(MatrixDescriptor, MatrixDescriptor,
+ MatrixDescriptor, se::Stream*)>
+FindGemmExecutor(PrimitiveType type) {
+ switch (type) {
+ case F32:
+ return &DoGemm<float>;
+ case F64:
+ return &DoGemm<double>;
+ default:
+ return nullptr;
+ }
+}
+
+} // namespace
+
+GemmThunk::GemmThunk(Index lhs_buffer, Index rhs_buffer, Index output_buffer,
+ const Shape& lhs_shape, const Shape& rhs_shape,
+ const Shape& output_shape, bool transpose_lhs,
+ bool transpose_rhs, const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kGemm, hlo_instruction),
+ lhs_buffer_(lhs_buffer),
+ rhs_buffer_(rhs_buffer),
+ output_buffer_(output_buffer),
+ lhs_shape_(lhs_shape),
+ rhs_shape_(rhs_shape),
+ output_shape_(output_shape),
+ transpose_lhs_(transpose_lhs),
+ transpose_rhs_(transpose_rhs) {}
+
+tensorflow::Status GemmThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ VLOG(2) << "Executing a GemmThunk";
+ auto executor = FindGemmExecutor(output_shape_.element_type());
+ DCHECK(executor != nullptr);
+
+ se::DeviceMemoryBase lhs_data =
+ buffer_allocations.GetDeviceAddress(lhs_buffer_);
+ se::DeviceMemoryBase rhs_data =
+ buffer_allocations.GetDeviceAddress(rhs_buffer_);
+ se::DeviceMemoryBase output_data =
+ buffer_allocations.GetDeviceAddress(output_buffer_);
+
+ // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
+ // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
+ // their layout. Therefore, we should treat dimension 0 as row and dimension 1
+ // as column when mapping a matrix Dot to BLAS gemm.
+ int64 output_num_rows = output_shape_.dimensions(0);
+ int64 output_num_cols = output_shape_.dimensions(1);
+
+  // BLAS gemm expects the inputs and the output to be in column-major order.
+  // Therefore, we need to convert a dot between row-major matrices to a dot
+  // between column-major matrices. The key insight for the conversion is that,
+  // in linear storage, matrix M in column-major order is identical to the
+  // transpose of M in row-major order. In other words,
+ //
+ // column-major(M) = row-major(M^T).
+ //
+ // Leveraging this insight, we can perform dot between row-major matrices as
+ // follows.
+ //
+ // row-major(C)
+ // = row-major(A x B) = column-major((A x B)^T) = column-major(B^T x A^T)
+ // = gemm(column-major(B^T), column-major(A^T))
+ // = gemm(row-major(B), row-major(A))
+ //
+ // Although we do not modify the content of A and B in linear memory, we
+ // should use the dimensions of B^T and A^T when calling gemm. For example,
+ // the leading dimension of the LHS matrix of gemm is the number of rows in
+ // B^T and thus the number of columns in B.
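+  //
+  // As a concrete illustration of the identity above (example shapes only):
+  // let A be 2x3 and B be 3x4, both row-major, so C = A x B is 2x4 row-major.
+  // The bytes holding row-major(C) also hold column-major(C^T), a 4x2
+  // column-major matrix. We therefore call gemm with B's buffer read as the
+  // 4x3 column-major matrix B^T (leading dimension 4) and A's buffer read as
+  // the 3x2 column-major matrix A^T (leading dimension 3); the 4x2
+  // column-major result (leading dimension 4) is exactly C in row-major order.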
+
+ auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape,
+ bool transpose) -> MatrixDescriptor {
+ bool is_row_major = shape.layout().minor_to_major(0) != 0;
+ bool layout_mismatch = shape.layout().minor_to_major(0) !=
+ output_shape_.layout().minor_to_major(0);
+ return MatrixDescriptor(data, transpose ^ layout_mismatch,
+ shape.dimensions(is_row_major),
+ shape.dimensions(!is_row_major));
+ };
+
+ const MatrixDescriptor lhs_descriptor =
+ make_descriptor(lhs_data, lhs_shape_, transpose_lhs_);
+ const MatrixDescriptor rhs_descriptor =
+ make_descriptor(rhs_data, rhs_shape_, transpose_rhs_);
+ if (output_shape_.layout().minor_to_major(0) == 0) {
+ return executor(
+ lhs_descriptor, rhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
+ stream);
+ } else {
+ return executor(
+ rhs_descriptor, lhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
+ stream);
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
new file mode 100644
index 0000000000..7c8574d275
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// This class stores everything that StreamExecutor needs to launch a BLAS gemm.
+// It is generated by IrEmitter.
+//
+// This is thread-compatible.
+class GemmThunk : public Thunk {
+ public:
+ // Constructs a thunk that computes "output = lhs <dot> rhs" using BLAS gemm.
+ // transpose_lhs and transpose_rhs indicate whether gemm should transpose the
+ // lhs and rhs operand. hlo_instruction is as in Thunk.
+ GemmThunk(BufferAllocation::Index lhs_buffer,
+ BufferAllocation::Index rhs_buffer,
+ BufferAllocation::Index output_buffer, const Shape& lhs_shape,
+ const Shape& rhs_shape, const Shape& output_shape,
+ bool transpose_lhs, bool transpose_rhs,
+ const HloInstruction* hlo_instruction);
+
+ GemmThunk(const GemmThunk&) = delete;
+ GemmThunk& operator=(const GemmThunk&) = delete;
+
+ // Does the gemm operation for the thunk on "stream", which must be non-null.
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ BufferAllocation::Index lhs_buffer_;
+ BufferAllocation::Index rhs_buffer_;
+ BufferAllocation::Index output_buffer_;
+
+ Shape lhs_shape_;
+ Shape rhs_shape_;
+ Shape output_shape_;
+
+ bool transpose_lhs_;
+ bool transpose_rhs_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
new file mode 100644
index 0000000000..a13279c6ff
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -0,0 +1,335 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h"
+
+#include <stdlib.h>
+#include <functional>
+#include <utility>
+
+#include "external/llvm/include/llvm/IR/DiagnosticInfo.h"
+#include "external/llvm/include/llvm/IR/DiagnosticPrinter.h"
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h"
+#include "tensorflow/compiler/xla/service/gpu/copy_insertion.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+#include "tensorflow/compiler/xla/service/transpose_folding.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/cuda_libdevice_path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/subprocess.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// The triple that represents our target.
+const char* kTargetTriple = "nvptx64-nvidia-cuda";
+
+// The data layout of the emitted module. Copied from computeDataLayout in
+// NVPTXTargetMachine.cpp.
+const char* kDataLayout = "e-i64:64-v16:16-v32:32-n16:32:64";
+
+// Returns the directory containing nvvm libdevice files. This function is
+// called in GpuCompiler's constructor, so it can't return an error. Instead,
+// GpuCompiler::Compile returns an error when the desired libdevice file does
+// not exist in the directory this function returns.
+string GetLibdeviceDir() {
+ std::vector<string> potential_libdevice_dirs;
+ // Flag xla_cuda_data_dir specified by the user.
+ legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
+ const string datadir = flags->xla_cuda_data_dir;
+ if (!datadir.empty()) {
+ potential_libdevice_dirs.push_back(datadir);
+ }
+ potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot());
+
+ // Tries all potential libdevice directories in the order they are inserted.
+ // Returns the first directory that exists in the file system.
+ for (const string& potential_libdevice_dir : potential_libdevice_dirs) {
+ if (tensorflow::Env::Default()->IsDirectory(potential_libdevice_dir).ok()) {
+ VLOG(2) << "Found libdevice dir " << potential_libdevice_dir;
+ return potential_libdevice_dir;
+ }
+ VLOG(2) << "Unable to find potential libdevice dir "
+ << potential_libdevice_dir;
+ }
+
+ // Last resort: maybe in the current folder.
+ return ".";
+}
+
+// Runs optimization passes on the given HLO module.
+tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
+ const Compiler::HloDumper& dump_hlo,
+ const se::DeviceDescription& device_desc) {
+ {
+ HloPassPipeline pipeline("optimization", dump_hlo);
+ {
+ auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
+ "simplification", dump_hlo);
+ pass.AddPass<AlgebraicSimplifier>(
+ /*is_layout_sensitive=*/false,
+ [](const Shape&, const Shape&) { return false; });
+ pass.AddPass<ReshapeMover>();
+ }
+ pipeline.AddPass<ConvolutionFolding>();
+ pipeline.AddPass<TransposeFolding>(ImplementedAsGemm);
+ pipeline.AddPass<HloSubcomputationUnification>();
+ pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
+ pipeline.AddPass<HloDCE>();
+ TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());
+ }
+ {
+ HloPassFix<HloPassPipeline> fusion("fusion", dump_hlo);
+ fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
+ fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
+ return fusion.Run(hlo_module).status();
+ }
+}
+
+// Modifies the given HLO module so that it will be accepted by IrEmitter.
+// Unlike the optimization passes above, these passes are necessary for
+// correctness.
+tensorflow::Status PrepareHloModuleForIrEmitting(
+ const Compiler::HloDumper& dump_hlo, HloModule* hlo_module,
+ HloModuleConfig* module_config) {
+ // In some cases, we have to place the result of an instruction in a temporary
+ // buffer. For instance, the buffer that holds an external parameter is
+ // assumed immutable at this point, and should not be reused for output
+ // (b/27180329). Therefore, in that case, we set the output to be a copy of
+ // the parameter.
+ HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo);
+ pipeline.AddPass<PadInsertion>();
+ pipeline.AddPass<GpuLayoutAssignment>(
+ module_config->mutable_entry_computation_layout());
+ // The LayoutAssignment pass may leave behind kCopy instructions which are
+ // duplicate or NOPs, so remove them with algebraic simplification and CSE.
+ pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(
+ /*is_layout_sensitive=*/true,
+ [](const Shape&, const Shape&) { return true; });
+ pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+ // Copy insertion should be performed immediately before IR emission to avoid
+  // inserting unnecessary copies (a later pass could add an instruction that
+  // materializes the value) or missing a necessary copy (a later pass could
+  // remove an instruction that materializes a value).
+ pipeline.AddPass<GpuCopyInsertion>();
+ pipeline.AddPass<HloDCE>();
+ return pipeline.Run(hlo_module).status();
+}
+
+// Invokes the ptxas tool on the given PTX string, and dumps its output.
+void DumpPtxasInfo(const string& ptx) {
+ legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
+ const string ptxas_path = flags->xla_ptxas_path;
+ // Do not log PTX stats if ptxas is not found at the given path.
+ if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
+ LOG(WARNING)
+ << "Failed to dump PTX stats because ptxas is not found at path \""
+ << ptxas_path << "\".";
+ return;
+ }
+
+ // Write `ptx` into a temporary file.
+ char tempdir_template[] = "/tmp/ptxXXXXXX";
+ char* tempdir_name = mkdtemp(tempdir_template);
+ CHECK_NOTNULL(tempdir_name);
+ string ptx_path = tensorflow::io::JoinPath(tempdir_name, "ptx");
+ TF_CHECK_OK(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), ptx_path, ptx));
+ LOG(INFO) << "ptx file written to: " << ptx_path;
+
+ // Invoke ptxas and collect its output.
+ tensorflow::SubProcess ptxas_info_dumper;
+ ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o",
+ "/dev/null", "-v", "-arch=sm_35"});
+ ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR,
+ tensorflow::ACTION_PIPE);
+ CHECK(ptxas_info_dumper.Start());
+ string stderr_output;
+ int exit_status = ptxas_info_dumper.Communicate(
+ /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output);
+ XLA_LOG_LINES(tensorflow::INFO, stderr_output);
+ if (exit_status != 0) {
+ LOG(FATAL) << "Invalid PTX. See the error message above for reasons.";
+ }
+}
+
+} // namespace
+
+GpuCompiler::GpuCompiler() : libdevice_dir_(GetLibdeviceDir()) {}
+
+StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ se::StreamExecutor* stream_exec) {
+ TF_RET_CHECK(stream_exec != nullptr);
+
+ TF_RETURN_IF_ERROR(OptimizeHloModule(hlo_module.get(), dump_hlo,
+ stream_exec->GetDeviceDescription()));
+ TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(dump_hlo, hlo_module.get(),
+ module_config.get()));
+
+ llvm::LLVMContext llvm_context;
+ std::string buffer;
+ llvm::raw_string_ostream error(buffer);
+ llvm::DiagnosticPrinterRawOStream printer(error);
+ auto DiagnosticHandler = [](const llvm::DiagnosticInfo& diag_info,
+ void* Context) {
+ auto printer = static_cast<llvm::DiagnosticPrinterRawOStream*>(Context);
+ diag_info.print(*printer);
+ };
+ llvm_context.setDiagnosticHandler(DiagnosticHandler, &printer);
+
+ llvm::Module llvm_module(hlo_module->name().c_str(), llvm_context);
+ // Set the target triple and the data layout.
+ llvm_module.setTargetTriple(kTargetTriple);
+ llvm_module.setDataLayout(kDataLayout);
+ const llvm::DataLayout& data_layout = llvm_module.getDataLayout();
+ int64 pointer_size = data_layout.getPointerSize();
+
+ // Determine the HLO schedule, which is an ordering of HLO instructions. This
+ // is used by buffer assignment to enable buffer reuse, and the same ordering
+ // must also be used to determine the thunk launch schedule.
+ std::unique_ptr<StreamAssignment> stream_assignment =
+ AssignStreams(*hlo_module);
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloSchedule> hlo_schedule,
+ HloSchedule::Build(*hlo_module, *stream_assignment));
+
+ // Run buffer analysis on the HLO graph. This analysis figures out which
+ // temporary buffers are required to run the computation.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> buffer_assignment,
+ BufferAssigner::Run(hlo_module.get(), hlo_schedule->ConsumeHloOrdering(),
+ pointer_size));
+ auto temp_buffer_offsets = MakeUnique<TempBufferOffsets>(*buffer_assignment);
+
+ IrEmitterContext ir_emitter_context(
+ hlo_module.get(), buffer_assignment.get(), temp_buffer_offsets.get(),
+ &stream_exec->GetDeviceDescription(), &llvm_module);
+
+ HloComputation* entry_computation = hlo_module->entry_computation();
+ IrEmitterUnnested ir_emitter(*module_config, entry_computation,
+ module_config->has_hybrid_result(),
+ &ir_emitter_context);
+ TF_RETURN_IF_ERROR(
+ entry_computation->root_instruction()->Accept(&ir_emitter));
+
+ string ir_module_string_before_opt;
+ legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
+ if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) {
+ ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module);
+ VLOG(2) << "LLVM module before optimizations:";
+ XLA_VLOG_LINES(2, ir_module_string_before_opt);
+ }
+
+ // Reserve space for the PTX to be generated for this module.
+ string* ptx;
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ generated_ptxes_.emplace_back(MakeUnique<string>());
+ ptx = generated_ptxes_.back().get();
+ }
+ TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, libdevice_dir_));
+
+ VLOG(2) << "LLVM module after optimizations:";
+ XLA_VLOG_LINES(2, llvm_ir::DumpModuleToString(llvm_module));
+ VLOG(2) << "PTX:";
+ XLA_VLOG_LINES(2, *ptx);
+ if (VLOG_IS_ON(2)) {
+ DumpPtxasInfo(*ptx);
+ }
+
+ auto thunk_schedule = MakeUnique<ThunkSchedule>(
+ ir_emitter.ConsumeThunkSequence(), std::move(stream_assignment),
+ hlo_schedule->ThunkLaunchOrder());
+ VLOG(2) << "Printing the thunk schedule...";
+ XLA_VLOG_LINES(2, thunk_schedule->ToString());
+
+ auto* gpu_executable =
+ new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(hlo_module),
+ std::move(module_config), std::move(buffer_assignment),
+ std::move(temp_buffer_offsets));
+ if (flags->xla_gpu_embed_ir) {
+ DCHECK_NE("", ir_module_string_before_opt);
+ gpu_executable->set_ir_module_string(ir_module_string_before_opt);
+ }
+ return std::unique_ptr<Executable>(gpu_executable);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> GpuCompiler::Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_modules,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ HloDumper dump_hlos, std::vector<se::StreamExecutor*> stream_execs) {
+ return Unimplemented(
+ "Compilation of multiple HLO modules is not yet supported on GPU.");
+}
+
+StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::CompileAheadOfTime(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ const AotCompilationOptions& options) {
+ return Unimplemented("not yet implemented: GpuCompiler::CompileAheadOfTime");
+}
+
+se::Platform::Id GpuCompiler::PlatformId() const {
+ return se::cuda::kCudaPlatformId;
+}
+
+} // namespace gpu
+} // namespace xla
+
+static bool InitModule() {
+ xla::Compiler::RegisterCompilerFactory(se::cuda::kCudaPlatformId, []() {
+ return xla::MakeUnique<xla::gpu::GpuCompiler>();
+ });
+ return true;
+}
+static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
new file mode 100644
index 0000000000..fefa403104
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace gpu {
+
+// The GPU compiler generates efficient GPU executables.
+class GpuCompiler : public Compiler {
+ public:
+ GpuCompiler();
+ ~GpuCompiler() override {}
+
+ StatusOr<std::unique_ptr<Executable>> Compile(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ perftools::gputools::StreamExecutor* stream_exec) override;
+
+ StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
+ std::vector<std::unique_ptr<HloModule>> hlo_module,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_config,
+ HloDumper dump_hlo,
+ std::vector<perftools::gputools::StreamExecutor*> stream_exec) override;
+
+ StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
+ AotCompilationOptions const& options) override;
+
+ perftools::gputools::Platform::Id PlatformId() const override;
+
+ private:
+ // The parent directory of libdevice IR libraries.
+ const string libdevice_dir_;
+
+ // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler
+  // own them because they need to stay alive across the life span of the
+  // StreamExecutor (b/24776264).
+ tensorflow::mutex mutex_;
+ std::vector<std::unique_ptr<string>> generated_ptxes_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_COMPILER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
new file mode 100644
index 0000000000..f654ffd22d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -0,0 +1,454 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+
+#include <set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// A helper class for profiling HLO in the course of GPU program execution.
+// All of the profiling is guarded internally, to avoid the caller needing to
+// have lots of conditionals sprinkled around.
+class HloExecutionProfiler {
+ public:
+  // If profiling is enabled, starts an execution timer.
+ explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile,
+ se::Stream* stream)
+ : do_profile_(do_profile), profile_(profile), stream_(stream) {
+ if (do_profile_) {
+ clock_rate_ghz_ =
+ stream->parent()->GetDeviceDescription().clock_rate_ghz();
+ execution_timer_.reset(new se::Timer(stream->parent()));
+ per_op_timer_.reset(new se::Timer(stream->parent()));
+ stream->InitTimer(execution_timer_.get())
+ .ThenStartTimer(execution_timer_.get());
+ stream->InitTimer(per_op_timer_.get());
+ }
+ }
+
+ // If profiling is enabled, sets the total cycle count on the profile from the
+ // execution timer.
+ ~HloExecutionProfiler() {
+ if (do_profile_) {
+ stream_->ThenStopTimer(execution_timer_.get());
+ stream_->BlockHostUntilDone();
+ profile_->set_total_cycles_executed(execution_timer_->Nanoseconds() *
+ clock_rate_ghz_);
+ }
+ }
+
+ // If profiling is enabled, starts the per-operation timer.
+ void StartOperation() {
+ if (do_profile_) {
+ stream_->ThenStartTimer(per_op_timer_.get());
+ }
+ }
+
+ // If profiling is enabled, stops the per-operation timer and records the time
+ // that the hlo_instruction took to execute in the profile.
+ void FinishOperation(const HloInstruction* hlo_instruction) {
+ if (do_profile_) {
+ stream_->ThenStopTimer(per_op_timer_.get());
+ stream_->BlockHostUntilDone();
+ profile_->AddProfileResult(
+ hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_);
+ }
+ }
+
+ private:
+ const bool do_profile_;
+ double clock_rate_ghz_;
+ HloExecutionProfile* profile_;
+ se::Stream* stream_;
+ std::unique_ptr<se::Timer> execution_timer_;
+ std::unique_ptr<se::Timer> per_op_timer_;
+};
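+
+// Usage sketch (mirroring ExecuteThunks below; shown only for illustration):
+//
+//   HloExecutionProfiler profiler(do_profile, hlo_execution_profile, stream);
+//   profiler.StartOperation();
+//   ... launch a thunk on its stream ...
+//   profiler.FinishOperation(thunk->hlo_instruction());
+//
+// The total cycle count is recorded when the profiler goes out of scope.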
+
+} // namespace
+
+// Implementation note: HLO profiling is always enabled for GPU executables,
+// since we can use timers around thunks.
+GpuExecutable::GpuExecutable(
+ tensorflow::StringPiece ptx, std::unique_ptr<ThunkSchedule> thunk_schedule,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<TempBufferOffsets> temp_buffer_offsets)
+ : Executable(std::move(hlo_module), std::move(module_config)),
+ ptx_(ptx),
+ thunk_schedule_(std::move(thunk_schedule)),
+ assignment_(std::move(assignment)),
+ temp_buffer_offsets_(std::move(temp_buffer_offsets)) {}
+
+Status GpuExecutable::ExecuteThunks(
+ se::Stream* main_stream, const BufferAllocations& buffer_allocations,
+ HloExecutionProfile* hlo_execution_profile) {
+ bool do_profile = hlo_execution_profile != nullptr;
+ if (do_profile) {
+ LOG(WARNING) << "PROFILING: profiling is enabled";
+ }
+ HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream);
+
+ std::vector<std::unique_ptr<se::Stream>> sub_streams;
+ // Stream 0 indicates `main_stream` and substreams start from stream 1.
+ for (int32 i = 1; i < thunk_schedule_->StreamCount(); ++i) {
+ auto sub_stream = MakeUnique<se::Stream>(main_stream->parent());
+ sub_stream->Init();
+ sub_streams.emplace_back(std::move(sub_stream));
+ }
+
+ std::map<const Thunk*, std::unique_ptr<se::Event>> thunk_to_finish_event;
+ for (Thunk* thunk : thunk_schedule_->TotalOrder()) {
+ TF_RETURN_IF_ERROR(thunk->Initialize(*this));
+ int32 stream_no =
+ thunk_schedule_->StreamNumberForHlo(*thunk->hlo_instruction());
+ se::Stream* stream =
+ (stream_no == 0 ? main_stream : sub_streams[stream_no - 1].get());
+
+ for (const Thunk* dependency : thunk_schedule_->DependsOn(thunk)) {
+ stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get());
+ }
+
+ profiler.StartOperation();
+ VLOG(2) << "Executing the thunk for "
+ << thunk->hlo_instruction()->ToString();
+ TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ if (thunk_schedule_->Depended(thunk)) {
+ auto finish_event = MakeUnique<se::Event>(main_stream->parent());
+ finish_event->Init();
+ stream->ThenRecordEvent(finish_event.get());
+ thunk_to_finish_event[thunk] = std::move(finish_event);
+ }
+ profiler.FinishOperation(thunk->hlo_instruction());
+ }
+
+ main_stream->ThenWaitFor(&sub_streams);
+ // Make sure kernels are completed before deallocating temporary buffers.
+ // TODO(b/30100571): we could potentially postpone deallocating the temp
+ // buffers until a different computation is executed.
+ if (!main_stream->BlockHostUntilDone()) {
+ return InternalError("Failed to complete all kernels launched on stream %p",
+ main_stream);
+ }
+
+ return Status::OK();
+}
+
+StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ // This ExecuteOnStream overload should only be called if has_hybrid_result is
+ // false.
+ TF_RET_CHECK(!module_config().has_hybrid_result());
+
+ BufferAllocations::Builder buffer_allocations_builder;
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ const BufferAllocation& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_entry_computation_parameter()) {
+ buffer_allocations_builder.RegisterBuffer(
+ i, arguments[allocation.parameter_number()]);
+ }
+ }
+ se::StreamExecutor* executor = stream->parent();
+ TF_ASSIGN_OR_RETURN(auto buffer_allocations,
+ buffer_allocations_builder.Build(
+ *assignment_, *temp_buffer_offsets_,
+ executor->device_ordinal(), memory_allocator));
+
+ TF_RETURN_IF_ERROR(
+ ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile));
+
+ HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* output_allocation,
+ assignment_->GetUniqueTopLevelOutputAllocation());
+ se::DeviceMemoryBase output_buffer_address =
+ buffer_allocations->GetDeviceAddress(output_allocation->index());
+
+ if (ShapeUtil::IsTuple(root->shape())) {
+ std::set<se::DeviceMemoryBase> referred_by_output;
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ // The points-to set of the root is ambiguous so we need to examine the
+ // result data to determine which buffers are contained in the result.
+ TF_ASSIGN_OR_RETURN(
+ TransferManager * transfer_manager,
+ TransferManager::GetForPlatform(executor->platform()));
+ TF_ASSIGN_OR_RETURN(referred_by_output,
+ transfer_manager->GatherBufferPointersFromTuple(
+ executor, output_buffer_address, root->shape()));
+ } else {
+ // The points-to set of the root is unambiguous so it's known statically
+ // which buffers are in the result. Gather these buffers using the root's
+ // points-to set.
+ TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElement(
+ [&referred_by_output, &buffer_allocations, this](
+ const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& buffers) {
+ // The points to set is unambiguous so the set should be a
+ // singleton. That is, we know exactly which instruction produced
+ // the array at this element.
+ CHECK_EQ(1, buffers.size());
+ HloInstruction* hlo = buffers[0]->instruction();
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ this->assignment_->GetUniqueAllocation(
+ hlo, buffers[0]->index()));
+ CHECK(!allocation->is_entry_computation_parameter());
+ referred_by_output.insert(
+ buffer_allocations->GetDeviceAddress(allocation->index()));
+ return Status::OK();
+ }));
+ }
+ TF_RETURN_IF_ERROR(
+ buffer_allocations->TearDown(referred_by_output, *assignment_));
+ } else {
+ // If the computation result is not a tuple, we can delete all temporary
+ // buffers that are not the output.
+ TF_RETURN_IF_ERROR(
+ buffer_allocations->TearDown({output_buffer_address}, *assignment_));
+ }
+ return output_buffer_address;
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> GpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ // This ExecuteOnStream overload should only be called by the LocalService
+ // which sets has_hybrid_result to true.
+ TF_RET_CHECK(module_config().has_hybrid_result());
+
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented("Points-to set of root instruction is ambiguous");
+ }
+
+ BufferAllocations::Builder buffer_allocations_builder;
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ const BufferAllocation& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_entry_computation_parameter()) {
+ auto param_no = allocation.parameter_number();
+ if (ShapeUtil::IsTuple(arguments[param_no]->shape())) {
+ return Unimplemented("Tuple ShapedBuffer arguments not supported");
+ }
+ buffer_allocations_builder.RegisterBuffer(
+ i, arguments[param_no]->buffer(/*index=*/{}));
+ }
+ }
+ se::StreamExecutor* executor = stream->parent();
+ TF_ASSIGN_OR_RETURN(auto buffer_allocations,
+ buffer_allocations_builder.Build(
+ *assignment_, *temp_buffer_offsets_,
+ executor->device_ordinal(), memory_allocator));
+
+ TF_RETURN_IF_ERROR(
+ ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile));
+
+ HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
+ auto device_ordinal = executor->device_ordinal();
+ TF_ASSIGN_OR_RETURN(auto shaped_buffer,
+ ShapedBuffer::MakeShapedBuffer(
+ root->shape(), executor->platform(), device_ordinal));
+
+ // Copy DeviceMemoryBase values which contain the array(s) of the result into
+ // the respective location in ShapedBuffer.
+ std::set<se::DeviceMemoryBase> buffers_in_result;
+ TF_RETURN_IF_ERROR(
+ shaped_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&buffer_allocations, &buffers_in_result, &shaped_buffer, this](
+ const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
+ if (is_leaf) {
+ const std::vector<const LogicalBuffer*>& sources =
+ this->GetRootPointsToSet().element(index);
+ // The points to set is unambiguous so the set should be a
+ // singleton. That is, we know exactly which instruction
+ // produced the array at this element.
+ CHECK_EQ(1, sources.size());
+ auto src_hlo = sources[0]->instruction();
+
+ VLOG(4) << "Looking at: " << sources[0];
+
+ // The source instruction should have a non-parameter buffer
+ // assigned.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ this->assignment_->GetUniqueAllocation(
+ src_hlo, sources[0]->index()));
+ CHECK(!allocation->is_entry_computation_parameter());
+
+ perftools::gputools::DeviceMemoryBase src_base =
+ buffer_allocations->GetDeviceAddress(allocation->index());
+ CHECK(!src_base.is_null() || src_base.size() == 0);
+ shaped_buffer->mutable_buffers()->push_back(src_base);
+ *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1;
+
+ buffers_in_result.insert(src_base);
+ }
+ return Status::OK();
+ }));
+ TF_RETURN_IF_ERROR(
+ buffer_allocations->TearDown(buffers_in_result, *assignment_));
+
+ return std::move(shaped_buffer);
+}
+
+Status GpuExecutable::ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
+ se::Stream* stream = run_options->stream();
+ DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+ // This ExecuteOnStream overload should only be called by the LocalService
+ // which sets has_hybrid_result to true.
+ TF_RET_CHECK(module_config().has_hybrid_result());
+
+ // Every array element in the result of the computation must be unambiguously
+ // produced by a single instruction.
+ // This ensures that the buffers inside result_buffer can be assigned without
+ // conflict to the respective instructions because there is a one-to-one
+ // correspondence between hlo instructions and array buffers in the result.
+ if (GetRootPointsToSet().IsAmbiguous()) {
+ return Unimplemented(
+ "Points-to set of root instruction is ambiguous or not distinct");
+ }
+
+ DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape()));
+
+ BufferAllocations::Builder buffer_allocations_builder;
+ for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+ ++i) {
+ const BufferAllocation& allocation = assignment_->GetAllocation(i);
+ if (allocation.is_entry_computation_parameter()) {
+ auto param_no = allocation.parameter_number();
+ if (ShapeUtil::IsTuple(arguments[param_no]->shape())) {
+ return Unimplemented("Tuple ShapedBuffer arguments not supported");
+ }
+ buffer_allocations_builder.RegisterBuffer(
+ i, arguments[param_no]->buffer(/*index=*/{}));
+ }
+ }
+
+ // If two tuple elements point to the same buffer, one of the results in the
+ // result buffer is considered the canonical location while the other result
+ // points to it (instead of, say, making a copy of the result).
+ // buffer_index_to_shape_index maps a buffer index to its canonical location
+ // in the result buffer.
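+  //
+  // For example (indices are illustrative only): if tuple elements {0} and {2}
+  // alias the same buffer index, the element visited first (say {0}) becomes
+  // the canonical location, and element {2}'s buffer entry is set to point at
+  // {0}'s entry instead of registering the underlying buffer a second time.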
+ std::unordered_map<BufferAllocation::Index, size_t>
+ buffer_index_to_shape_index;
+
+ // Register DeviceMemoryBase values in result_buffer to their corresponding
+ // buffer indices. These buffers will not be allocated in the call to
+ // BufferAllocationsBuilder::Build.
+ std::set<se::DeviceMemoryBase> buffers_in_result;
+ TF_RETURN_IF_ERROR(
+ result_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&buffer_allocations_builder, &buffers_in_result,
+ &buffer_index_to_shape_index, result_buffer, this](
+ const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
+ if (is_leaf) {
+ const std::vector<const LogicalBuffer*>& sources =
+ this->GetRootPointsToSet().element(index);
+ // The points to set is unambiguous so the set should be a
+ // singleton. That is, we know exactly which instruction
+ // produced the array at this element.
+ CHECK_EQ(1, sources.size());
+ auto src_hlo = sources[0]->instruction();
+
+ VLOG(4) << "Looking at: " << sources[0];
+
+ // The source instruction should have a non-parameter buffer
+ // assigned.
+ TF_ASSIGN_OR_RETURN(const BufferAllocation* allocation,
+ this->assignment_->GetUniqueAllocation(
+ src_hlo, sources[0]->index()));
+ CHECK(!allocation->is_entry_computation_parameter());
+
+ auto insert_result = buffer_index_to_shape_index.emplace(
+ allocation->index(), *buffer_entry);
+ if (insert_result.second) {
+ // The points-to set is distinct so this buffer should not
+ // have been assigned in a previous invocation of this
+ // lambda.
+ perftools::gputools::DeviceMemoryBase memory_base =
+ result_buffer->buffer(index);
+ CHECK(!memory_base.is_null());
+ buffer_allocations_builder.RegisterBuffer(
+ allocation->index(), memory_base);
+ buffers_in_result.insert(memory_base);
+ } else {
+                      // Record the fact that this tuple element is identical
+                      // to some prior result.
+ *buffer_entry = insert_result.first->second;
+ }
+ }
+ return Status::OK();
+ }));
+
+ se::StreamExecutor* executor = stream->parent();
+ auto device_ordinal = executor->device_ordinal();
+ TF_ASSIGN_OR_RETURN(
+ auto buffer_allocations,
+ buffer_allocations_builder.Build(*assignment_, *temp_buffer_offsets_,
+ device_ordinal, memory_allocator));
+
+ TF_RETURN_IF_ERROR(
+ ExecuteThunks(stream, *buffer_allocations, hlo_execution_profile));
+
+ return buffer_allocations->TearDown(buffers_in_result, *assignment_);
+}
+
+StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ // TODO(b/30671675): Implement asynchronous execution mode.
+ return Unimplemented(
+ "Asynchronous execution on stream is not yet supported on GPU.");
+}
+
+const PointsToSet& GpuExecutable::GetRootPointsToSet() const {
+ return assignment_->points_to_analysis().GetPointsToSet(
+ module().entry_computation()->root_instruction());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
new file mode 100644
index 0000000000..2343d264de
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -0,0 +1,130 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// GPU-targeting implementation of the XLA Executable interface.
+//
+// Launches the given CUDA kernel via the StreamExecutor.
+//
+// This is an immutable data type after initialization, and thus thread safe.
+class GpuExecutable : public Executable {
+ public:
+ GpuExecutable(tensorflow::StringPiece ptx,
+ std::unique_ptr<ThunkSchedule> thunk_schedule,
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ std::unique_ptr<BufferAssignment> assignment,
+ std::unique_ptr<TempBufferOffsets> temp_buffer_offsets);
+
+ // This should be called after set_ir_module_string.
+ const string& ir_module_string() const { return ir_module_string_; }
+
+ // This should be called before ExecuteOnStream.
+ void set_ir_module_string(const string& ir_module_string) {
+ ir_module_string_ = ir_module_string;
+ }
+
+ // Returns the compiled PTX for the computation.
+ tensorflow::StringPiece ptx() const { return ptx_; }
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ Status ExecuteOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result_buffer,
+ HloExecutionProfile* hlo_execution_profile) override;
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
+ const ExecutableRunOptions* run_options,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments) override;
+
+ private:
+ Status ExecuteThunks(perftools::gputools::Stream* stream,
+ const BufferAllocations& buffer_allocations,
+ HloExecutionProfile* hlo_execution_profile);
+
+ // Returns the points-to set of the root instruction of the entry
+ // computation. Uses points-to analysis from buffer assignment.
+ const PointsToSet& GetRootPointsToSet() const;
+
+ // The LLVM IR, in string format, of the unoptimized module generated for this
+ // GpuExecutable. We save a string instead of an llvm::Module* because leaving
+ // llvm::Module* in a singleton can cause the heap checker to emit false
+ // positives.
+ //
+ // This string should be modified only before ExecuteOnStream.
+ string ir_module_string_;
+
+ // The reference to the compiled PTX for the computation.
+ const tensorflow::StringPiece ptx_;
+
+ // The thunks to be invoked by this GpuExecutable. They are generated by the
+ // IrEmitter.
+ const std::unique_ptr<ThunkSchedule> thunk_schedule_;
+
+ // Owns the buffer data at runtime. It provides information to allocate
+  // memory for every output/temp buffer.
+ const std::unique_ptr<BufferAssignment> assignment_;
+
+ // Owns the mapping from temporary buffers to their offsets in the temp-buffer
+ // memory block.
+ const std::unique_ptr<TempBufferOffsets> temp_buffer_offsets_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
new file mode 100644
index 0000000000..c56b0ee9ca
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+#include <memory>
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// An HLO partial ordering based on the actual stream assignment and thunk
+// launch order.
+class GpuHloOrdering : public PredecessorHloOrdering {
+ public:
+ GpuHloOrdering(const HloModule* module,
+ const StreamAssignment& stream_assignment,
+ const std::vector<const HloInstruction*>& thunk_launch_order);
+ ~GpuHloOrdering() override = default;
+
+ string ToString() const override { return ToStringHelper("GpuHloOrdering"); }
+};
+
+GpuHloOrdering::GpuHloOrdering(
+ const HloModule* module, const StreamAssignment& stream_assignment,
+ const std::vector<const HloInstruction*>& thunk_launch_order)
+ : PredecessorHloOrdering(module) {
+ // The ordering of instructions for the entry computation is determined by the
+  // total order of thunk launches and the stream assignment. Instructions are
+ // sequential within a stream and concurrent across streams. In addition, the
+ // GpuExecutable adds cross-stream dependency edges to ensure each instruction
+ // waits for its operands before executing.
+ //
+ // The predecessor map is built incrementally, in thunk launch
+ // order. We record the instructions already visited per stream in
+ // 'instructions_per_stream'. This lets us quickly determine the
+ // same-stream predecessors of each instruction. To capture
+ // cross-stream dependency edges, we use the predecessor map to
+ // insert each operand as well as its transitive closure of
+ // dependencies.
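+  //
+  // For example (streams and ops are illustrative): if thunks for A and B are
+  // launched on stream 0 in that order, and a thunk for C on stream 1 consumes
+  // B, then B gets A as a same-stream predecessor, while C gets B plus B's
+  // transitive closure (and hence A) as cross-stream predecessors.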
+
+  // Compute the set of all instructions we will want to set reachability on.
+ auto predecessor_map = MakeUnique<HloComputation::ReachabilityMap>(
+ module->entry_computation()->MakeInstructionPostOrder());
+
+ std::vector<std::vector<const HloInstruction*>> instructions_per_stream(
+ stream_assignment.StreamCount());
+
+ for (const HloInstruction* hlo : thunk_launch_order) {
+ if (stream_assignment.HasStreamAssigned(*hlo)) {
+ // All ops already queued on the same stream are predecessors.
+ const int stream_no = stream_assignment.StreamNumberForHlo(*hlo);
+ for (const HloInstruction* inst : instructions_per_stream[stream_no]) {
+ predecessor_map->SetReachable(hlo, inst);
+ }
+ // All operands and their transitive predecessors are predecessors. Each
+ // operand must already exist in 'predecessor_map', since we're iterating
+ // in thunk launch order.
+ for (const HloInstruction* operand : hlo->operands()) {
+ predecessor_map->SetReachableAndTransitiveClosure(hlo, operand);
+ }
+ instructions_per_stream[stream_no].push_back(hlo);
+ } else {
+ // Only parameters and constants don't have an assigned stream, since they
+ // don't require a thunk. These ops don't have any predecessors.
+ CHECK(hlo->opcode() == HloOpcode::kParameter ||
+ hlo->opcode() == HloOpcode::kConstant);
+ CHECK_EQ(hlo->operand_count(), 0);
+ }
+ }
+ strict_predecessors_.emplace(module->entry_computation(),
+ std::move(predecessor_map));
+
+ // The ordering of instructions in subcomputations is based solely on data
+ // dependencies. I.e. the strict predecessors of each subcomputation
+ // instruction is its transitive operands.
+ //
+  // TODO(toddw): Each subcomputation is actually emitted as a function in DFS
+  // postorder, so we can do better and establish the total order here. We
+  // don't do that yet since it's hard to ensure that the order here is the
+  // order used by IrEmitterNested. And mismatched ordering bugs would be hard
+  // to find.
+ for (auto& computation : module->computations()) {
+ if (computation.get() != module->entry_computation()) {
+ strict_predecessors_.emplace(computation.get(),
+ computation->ComputeTransitiveOperands());
+ }
+ }
+}
+
+// Computes a topological launch_order based on depth-first order, visiting
+// operands in essentially an arbitrary order.
+//
+// TODO(b/32006145): Use an ordering that minimizes memory pressure.
+tensorflow::Status DFSLaunchOrder(
+ const HloComputation* computation,
+ std::vector<const HloInstruction*>* launch_order) {
+ return computation->root_instruction()->Accept(
+ [launch_order](HloInstruction* hlo) {
+ launch_order->push_back(hlo);
+ return tensorflow::Status::OK();
+ });
+}
+
+// Computes a topological launch_order that is close to a breadth-first
+// order. This heuristic works well for graphs where concurrent kernels are
+// located at the same layer. It can often reduce dependencies between
+// concurrent GEMMs that would otherwise be imposed by intra-stream total
+// orders. E.g. consider the following HLO
+// graph where the numbers in the parens indicate the stream assigned to each
+// HLO.
+//
+// A(0) -> D(0) -> E(1)
+// |
+// v
+// B(0)
+// |
+// v
+// C(0)
+//
+// If the total order is A,B,C,D,E, then C and E would be sequentialized
+// because C completes before D starts in stream 0, and E depends on D.
+// However, if the total order is A,B,D,C,E, then C and E can run
+// concurrently.
+void BFSLaunchOrder(const HloComputation* computation,
+ std::vector<const HloInstruction*>* launch_order) {
+ // This topological sort uses two data structures:
+ // 1. `incoming_edge_count` which keeps track of the number of incoming
+ // edges to each HLO;
+ // 2. `queue` which contains all HLOs with no incoming edges.
+ //
+ // The sorting algorithm repeatedly pops the top from the queue and deletes
+ // that HLO from the graph, making more HLOs incoming-edge free.
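+  // (This is Kahn's algorithm for topological sorting; popping from the front
+  // of the deque yields the breadth-first flavor described above.)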
+ std::deque<const HloInstruction*> queue;
+ std::unordered_map<const HloInstruction*, int64> incoming_edge_count;
+ for (const auto& hlo : computation->instructions()) {
+ if (hlo->operand_count() == 0) {
+ queue.push_back(hlo.get());
+ } else {
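+      // The std::set deduplicates operands, so an HLO that consumes the same
+      // operand more than once is counted as having a single incoming edge
+      // from that operand.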
+ incoming_edge_count[hlo.get()] =
+ std::set<HloInstruction*>(hlo->operands().begin(),
+ hlo->operands().end())
+ .size();
+ }
+ }
+
+ while (!queue.empty()) {
+ const HloInstruction* x = queue.front();
+ queue.pop_front();
+ launch_order->push_back(x);
+ for (const HloInstruction* y : x->users()) {
+ --incoming_edge_count[y];
+ if (incoming_edge_count[y] == 0) {
+ queue.push_back(y);
+ }
+ }
+ }
+}
+
+} // end namespace
+
+HloSchedule::HloSchedule() {}
+
+/* static */
+StatusOr<std::unique_ptr<HloSchedule>> HloSchedule::Build(
+ const HloModule& module, const StreamAssignment& stream_assignment) {
+ std::unique_ptr<HloSchedule> schedule(new HloSchedule);
+
+ // Initialize thunk_launch_order_, the total order of thunk launches.
+ const HloComputation* computation = module.entry_computation();
+ if (stream_assignment.StreamCount() == 1) {
+ // DFS tends to increase buffer reuse, reducing memory usage. All kernels
+ // are launched on a single stream, so there's no loss of concurrency.
+ TF_RETURN_IF_ERROR(
+ DFSLaunchOrder(computation, &schedule->thunk_launch_order_));
+ } else {
+ // BFS tends to increase concurrency, but also increases memory usage.
+ BFSLaunchOrder(computation, &schedule->thunk_launch_order_);
+ }
+
+ schedule->hlo_ordering_ = MakeUnique<GpuHloOrdering>(
+ &module, stream_assignment, schedule->thunk_launch_order_);
+
+ return std::move(schedule);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
new file mode 100644
index 0000000000..42d9051aed
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
@@ -0,0 +1,67 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_liveness.h"
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Determines the schedule of HLO instructions, represented by the total order
+// of thunk launches, and the partial order of HLO instructions. The HLO
+// instructions are only partially ordered, despite the total ordering of thunk
+// launches, because thunks may be scheduled onto concurrent streams. This
+// schedule is used by BufferAssigner to determine buffer liveness (i.e. to
+// minimize allocations), and also by ThunkSchedule to determine the thunk
+// launch order.
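+//
+// A typical use (sketch; mirrors the accompanying tests) is:
+//
+//   auto schedule =
+//       HloSchedule::Build(module, stream_assignment).ConsumeValueOrDie();
+//   const auto& launch_order = schedule->ThunkLaunchOrder();
+//   std::unique_ptr<HloOrdering> ordering = schedule->ConsumeHloOrdering();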
+class HloSchedule {
+ public:
+ // Constructs an HloSchedule for the given module, based on the given stream
+ // assignment.
+ static StatusOr<std::unique_ptr<HloSchedule>> Build(
+ const HloModule& module, const StreamAssignment& stream_assignment);
+
+ // Returns the total order of thunk launches, represented in terms of HLO
+ // instructions.
+ const std::vector<const HloInstruction*>& ThunkLaunchOrder() const {
+ return thunk_launch_order_;
+ }
+
+ // Returns the partial order of HLO instructions. This method may only be
+ // called once. The order is based on the total order of thunk launches, the
+ // stream assignment, and the data dependencies in the HLO DAG.
+ std::unique_ptr<HloOrdering> ConsumeHloOrdering() {
+ return std::move(hlo_ordering_);
+ }
+
+ private:
+ HloSchedule();
+
+ std::vector<const HloInstruction*> thunk_launch_order_;
+ std::unique_ptr<HloOrdering> hlo_ordering_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
new file mode 100644
index 0000000000..174982a6ce
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule_test.cc
@@ -0,0 +1,368 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
+
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+class HloScheduleTest : public HloTestBase {
+ protected:
+ typedef std::vector<const HloInstruction*> hlovec;
+
+ // Pre-canned shapes.
+ Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
+};
+
+// Test of a single stream, where data dependencies fully determine the
+// execution order.
+TEST_F(HloScheduleTest, SequentialMatMul) {
+ HloComputation::Builder builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+ HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
+ HloInstruction* dot1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction* dot2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(dot2));
+
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ EXPECT_EQ(streams->StreamNumberForHlo(*dot1),
+ streams->StreamNumberForHlo(*dot2));
+
+ auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+ EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, dot1, z, dot2}));
+
+ // Parameters x,y,z are mutually unordered, while dot1 and dot2 are
+ // transitively ordered by operands.
+ auto order = schedule->ConsumeHloOrdering();
+ EXPECT_TRUE(order->ExecutesBefore(x, dot1));
+ EXPECT_TRUE(order->ExecutesBefore(x, dot2));
+ EXPECT_TRUE(order->ExecutesBefore(y, dot1));
+ EXPECT_TRUE(order->ExecutesBefore(y, dot2));
+ EXPECT_TRUE(order->ExecutesBefore(z, dot2));
+ EXPECT_TRUE(order->ExecutesBefore(dot1, dot2));
+
+ EXPECT_FALSE(order->ExecutesBefore(x, x));
+ EXPECT_FALSE(order->ExecutesBefore(x, y));
+ EXPECT_FALSE(order->ExecutesBefore(x, z));
+ EXPECT_FALSE(order->ExecutesBefore(y, x));
+ EXPECT_FALSE(order->ExecutesBefore(y, y));
+ EXPECT_FALSE(order->ExecutesBefore(y, z));
+ EXPECT_FALSE(order->ExecutesBefore(z, x));
+ EXPECT_FALSE(order->ExecutesBefore(z, y));
+ EXPECT_FALSE(order->ExecutesBefore(z, z));
+ EXPECT_FALSE(order->ExecutesBefore(z, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, x));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, y));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, z));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, x));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, y));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, z));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
+}
+
+// Test of a single stream, where data dependencies do not fully determine the
+// execution order, but the stream assignment does.
+TEST_F(HloScheduleTest, SequentialAdd) {
+ HloComputation::Builder builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+ HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, x, y));
+ HloInstruction* add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, y, z));
+ HloInstruction* add3 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, add1, add2));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(add3));
+
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ EXPECT_EQ(streams->StreamNumberForHlo(*add1),
+ streams->StreamNumberForHlo(*add2));
+ EXPECT_EQ(streams->StreamNumberForHlo(*add1),
+ streams->StreamNumberForHlo(*add3));
+
+ auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+ EXPECT_EQ(schedule->ThunkLaunchOrder(), hlovec({x, y, add1, z, add2, add3}));
+
+ // Parameters x,y,z are mutually unordered, while add1, add2 and add3 are
+ // transitively ordered by operands.
+ auto order = schedule->ConsumeHloOrdering();
+ EXPECT_TRUE(order->ExecutesBefore(x, add1));
+ EXPECT_TRUE(order->ExecutesBefore(x, add3));
+ EXPECT_TRUE(order->ExecutesBefore(y, add1));
+ EXPECT_TRUE(order->ExecutesBefore(y, add2));
+ EXPECT_TRUE(order->ExecutesBefore(y, add3));
+ EXPECT_TRUE(order->ExecutesBefore(z, add2));
+ EXPECT_TRUE(order->ExecutesBefore(z, add3));
+ EXPECT_TRUE(order->ExecutesBefore(add1, add3));
+ EXPECT_TRUE(order->ExecutesBefore(add2, add3));
+ // The HLO graph does not define an ordering for add1 and add2, but their
+ // assignment onto the same stream does define an ordering.
+ if (order->ExecutesBefore(add1, add2)) {
+ EXPECT_FALSE(order->ExecutesBefore(add2, add1));
+ } else {
+ EXPECT_TRUE(order->ExecutesBefore(add2, add1));
+ EXPECT_FALSE(order->ExecutesBefore(add1, add2));
+ }
+
+ EXPECT_FALSE(order->ExecutesBefore(x, x));
+ EXPECT_FALSE(order->ExecutesBefore(x, y));
+ EXPECT_FALSE(order->ExecutesBefore(x, z));
+ EXPECT_FALSE(order->ExecutesBefore(y, x));
+ EXPECT_FALSE(order->ExecutesBefore(y, y));
+ EXPECT_FALSE(order->ExecutesBefore(y, z));
+ EXPECT_FALSE(order->ExecutesBefore(z, x));
+ EXPECT_FALSE(order->ExecutesBefore(z, y));
+ EXPECT_FALSE(order->ExecutesBefore(z, z));
+ EXPECT_FALSE(order->ExecutesBefore(x, add2));
+ EXPECT_FALSE(order->ExecutesBefore(z, add1));
+ EXPECT_FALSE(order->ExecutesBefore(add1, x));
+ EXPECT_FALSE(order->ExecutesBefore(add1, y));
+ EXPECT_FALSE(order->ExecutesBefore(add1, z));
+ EXPECT_FALSE(order->ExecutesBefore(add1, add1));
+ EXPECT_FALSE(order->ExecutesBefore(add2, x));
+ EXPECT_FALSE(order->ExecutesBefore(add2, y));
+ EXPECT_FALSE(order->ExecutesBefore(add2, z));
+ EXPECT_FALSE(order->ExecutesBefore(add2, add2));
+}
+
+// Test of two streams.
+TEST_F(HloScheduleTest, ConcurrentMatMul) {
+ HloComputation::Builder builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+ HloInstruction* dot1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction* dot2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(add));
+
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ EXPECT_NE(streams->StreamNumberForHlo(*dot1),
+ streams->StreamNumberForHlo(*dot2));
+
+ auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+ EXPECT_TRUE(schedule->ThunkLaunchOrder() == hlovec({x, y, dot1, dot2, add}) ||
+ schedule->ThunkLaunchOrder() == hlovec({x, y, dot2, dot1, add}));
+
+ // Parameters x,y are mutually unordered, while dot1, dot2 and add are
+ // transitively ordered by operands.
+ auto order = schedule->ConsumeHloOrdering();
+ EXPECT_TRUE(order->ExecutesBefore(x, dot1));
+ EXPECT_TRUE(order->ExecutesBefore(x, dot2));
+ EXPECT_TRUE(order->ExecutesBefore(y, dot1));
+ EXPECT_TRUE(order->ExecutesBefore(y, dot2));
+ EXPECT_TRUE(order->ExecutesBefore(dot1, add));
+ EXPECT_TRUE(order->ExecutesBefore(dot2, add));
+
+ EXPECT_FALSE(order->ExecutesBefore(x, x));
+ EXPECT_FALSE(order->ExecutesBefore(x, y));
+ EXPECT_FALSE(order->ExecutesBefore(y, x));
+ EXPECT_FALSE(order->ExecutesBefore(y, y));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, x));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, y));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(dot1, dot2));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, x));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, y));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(dot2, dot2));
+ EXPECT_FALSE(order->ExecutesBefore(add, x));
+ EXPECT_FALSE(order->ExecutesBefore(add, y));
+ EXPECT_FALSE(order->ExecutesBefore(add, dot1));
+ EXPECT_FALSE(order->ExecutesBefore(add, dot2));
+ EXPECT_FALSE(order->ExecutesBefore(add, add));
+}
+
+// Test of multiple streams.
+TEST_F(HloScheduleTest, LatticeMatMul) {
+ // d00 -- layer 0
+ // / \
+ // d10 d11 -- layer 1
+ // / \ / \
+ // d20 d21 d22 -- layer 2
+ // \ / \ /
+ // d30 d31 -- layer 3
+ // \ /
+ // d40 -- layer 4
+ HloComputation::Builder builder("entry_computation");
+ std::vector<HloInstruction*> params;
+ for (int i = 0; i < 6; ++i) {
+ params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
+ i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ }
+ HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32_2x2_, HloOpcode::kDot, params[2], params[3]));
+ HloInstruction* d10 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
+ HloInstruction* d11 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
+ HloInstruction* d20 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
+ HloInstruction* d21 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
+ HloInstruction* d22 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
+ HloInstruction* d30 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
+ HloInstruction* d31 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
+ HloInstruction* d40 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(d40));
+
+ std::unique_ptr<StreamAssignment> streams = AssignStreams(module);
+ // The two dots on layer 1 are concurrent.
+ EXPECT_NE(streams->StreamNumberForHlo(*d10),
+ streams->StreamNumberForHlo(*d11));
+ // The three dots on layer 2 are concurrent.
+ EXPECT_NE(streams->StreamNumberForHlo(*d20),
+ streams->StreamNumberForHlo(*d21));
+ EXPECT_NE(streams->StreamNumberForHlo(*d20),
+ streams->StreamNumberForHlo(*d22));
+ EXPECT_NE(streams->StreamNumberForHlo(*d21),
+ streams->StreamNumberForHlo(*d22));
+ // The two dots on layer 3 are concurrent.
+ EXPECT_NE(streams->StreamNumberForHlo(*d30),
+ streams->StreamNumberForHlo(*d31));
+
+ // We don't check the thunk launch order, since there are many valid total
+ // orders, and they are tedious to enumerate.
+ auto schedule = HloSchedule::Build(module, *streams).ConsumeValueOrDie();
+
+ auto order = schedule->ConsumeHloOrdering();
+ const hlovec all_params(
+ {params[0], params[1], params[2], params[3], params[4], params[5]});
+ const hlovec all_ops({d00, d10, d11, d20, d21, d22, d30, d31, d40});
+
+ // Parameters are mutually unordered, and never execute before ops.
+ for (const HloInstruction* param : all_params) {
+ for (const HloInstruction* param2 : all_params) {
+ EXPECT_FALSE(order->ExecutesBefore(param, param2));
+ }
+ for (const HloInstruction* op : all_ops) {
+ EXPECT_FALSE(order->ExecutesBefore(op, param));
+ }
+ }
+
+ // Check ordering of params before ops.
+ for (const HloInstruction* op : all_ops) {
+ if (op == d20 || op == d30 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(params[0], op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(params[0], op));
+ }
+ if (op != d00 && op != d11 && op != d22) {
+ EXPECT_TRUE(order->ExecutesBefore(params[1], op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(params[1], op));
+ }
+ EXPECT_TRUE(order->ExecutesBefore(params[2], op));
+ EXPECT_TRUE(order->ExecutesBefore(params[3], op));
+ if (op != d00 && op != d10 && op != d20) {
+ EXPECT_TRUE(order->ExecutesBefore(params[4], op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(params[4], op));
+ }
+ if (op == d22 || op == d31 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(params[5], op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(params[5], op));
+ }
+ }
+
+ // Check ordering of ops before ops.
+ for (const HloInstruction* op : all_ops) {
+ if (op != d00) {
+ EXPECT_TRUE(order->ExecutesBefore(d00, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d00, op));
+ }
+
+ if (op == d20 || op == d21 || op == d30 || op == d31 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d10, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d10, op));
+ }
+
+ if (op == d21 || op == d22 || op == d30 || op == d31 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d11, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d11, op));
+ }
+
+ if (op == d30 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d20, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d20, op));
+ }
+
+ if (op == d30 || op == d31 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d21, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d21, op));
+ }
+
+ if (op == d31 || op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d22, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d22, op));
+ }
+
+ if (op == d40) {
+ EXPECT_TRUE(order->ExecutesBefore(d30, op));
+ EXPECT_TRUE(order->ExecutesBefore(d31, op));
+ } else {
+ EXPECT_FALSE(order->ExecutesBefore(d30, op));
+ EXPECT_FALSE(order->ExecutesBefore(d31, op));
+ }
+
+ EXPECT_FALSE(order->ExecutesBefore(d40, op));
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
new file mode 100644
index 0000000000..accc406c76
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+void HloToIrBindings::EmitBasePointersForHlos(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos) {
+ // I/O HLOs are bound to the arguments of the current IR function. I.e.,
+ //
+ // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
+ llvm::Function* function = ir_builder_->GetInsertBlock()->getParent();
+ CHECK_EQ(io_hlos.size() + 1, function->arg_size());
+
+ // An HLO can have duplicated operands. This data structure remembers which
+ // operand HLOs are already bound to avoid rebinding the same HLO.
+ std::set<const HloInstruction*> already_bound_for_this_function;
+ auto arg_iter = function->arg_begin();
+ for (const auto* io_hlo : io_hlos) {
+ if (!already_bound_for_this_function.count(io_hlo)) {
+ if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
+ BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
+ } else {
+ BindHloToIrValue(*io_hlo, &*arg_iter);
+ }
+ already_bound_for_this_function.insert(io_hlo);
+ }
+ ++arg_iter;
+ }
+
+ temp_buffer_base_ = &*arg_iter;
+ temp_buffer_base_->setName("temp_buffer");
+
+ for (auto* non_io_hlo : non_io_hlos) {
+ if (already_bound_for_this_function.count(non_io_hlo)) {
+ continue;
+ }
+ already_bound_for_this_function.insert(non_io_hlo);
+
+ if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
+ if (!is_nested_) {
+ // Look up the buffer allocation of the GetTupleElement's non-GTE ancestor.
+ const BufferAllocation* allocation =
+ buffer_assignment_
+ ->GetUniqueTopLevelAllocation(LatestNonGteAncestor(non_io_hlo))
+ .ConsumeValueOrDie();
+ // We are not in a nested context, so the allocation must not be thread-local.
+ CHECK(!allocation->is_thread_local());
+ int64 offset = temp_buffer_offsets_->GetOffset(allocation->index());
+ CHECK_NE(nullptr, temp_buffer_base_);
+ // Emit IR for GetTupleElement instruction and bind to emitted value.
+ llvm::Value* base_ptr = ir_builder_->CreateInBoundsGEP(
+ temp_buffer_base_, ir_builder_->getInt64(offset));
+ BindHloToIrValue(*non_io_hlo,
+ EmitGetTupleElement(non_io_hlo, base_ptr));
+ }
+ continue;
+ }
+
+ if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
+ continue;
+ }
+
+ // A non-IO HLO with a buffer is bound to
+ // (1) an alloca if it is thread-local, or
+ // (2) an internal pointer in temp_buffer_base according to its offset.
+ const BufferAllocation* allocation =
+ buffer_assignment_->GetUniqueTopLevelAllocation(non_io_hlo)
+ .ConsumeValueOrDie();
+ if (allocation->is_thread_local()) {
+ llvm::Type* pointee_type =
+ llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_);
+ BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type));
+ } else {
+ int64 offset = temp_buffer_offsets_->GetOffset(allocation->index());
+ CHECK_NE(nullptr, temp_buffer_base_);
+ BindHloToIrValue(*non_io_hlo,
+ ir_builder_->CreateInBoundsGEP(
+ temp_buffer_base_, ir_builder_->getInt64(offset)));
+ }
+ }
+}
+
+llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
+ llvm::Value* base_ptr) {
+ // TODO(b/26344050): tighten the alignment based on the real element type.
+ if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
+ return llvm_ir::EmitGetTupleElement(
+ gte->shape(), gte->tuple_index(), /*alignment=*/1,
+ GetTypedIrValue(*gte->operand(0), base_ptr), ir_builder_);
+ }
+ return llvm_ir::EmitGetTupleElement(
+ gte->shape(), gte->tuple_index(), /*alignment=*/1,
+ EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_);
+}
+
+llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
+ llvm::Value* ir_value) {
+ llvm::Type* pointee_type = llvm_ir::ShapeToIrType(hlo.shape(), ir_builder_);
+ llvm::Type* dest_type = pointee_type->getPointerTo();
+
+ llvm::Value* typed_ir_value;
+ if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
+ typed_ir_value = llvm::ConstantExpr::getBitCast(
+ llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
+ } else {
+ typed_ir_value =
+ ir_builder_->CreateBitCast(ir_value, pointee_type->getPointerTo());
+ }
+ string ir_value_name = llvm_ir::SanitizeIrName(hlo.name());
+ ir_value->setName(llvm_ir::AsStringRef(ir_value_name + ".raw"));
+ typed_ir_value->setName(llvm_ir::AsStringRef(ir_value_name + ".typed"));
+ return typed_ir_value;
+}
+
+void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
+ llvm::Value* ir_value) {
+ VLOG(2) << "Binding " << hlo.ToString();
+ InsertOrDie(&base_ptrs_, &hlo, GetTypedIrValue(hlo, ir_value));
+}
+
+llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo) {
+ llvm_ir::IrArray ir_array(GetBasePointer(hlo), hlo.shape());
+ alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array);
+ return ir_array;
+}
+
+void HloToIrBindings::UnbindAllLocalIrValues() {
+ std::vector<const HloInstruction*> hlos_to_unbind;
+ for (auto& key_value : base_ptrs_) {
+ if (!llvm::isa<llvm::GlobalVariable>(
+ key_value.second->stripPointerCasts())) {
+ hlos_to_unbind.push_back(key_value.first);
+ }
+ }
+ for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
+ VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
+ base_ptrs_.erase(hlo_to_unbind);
+ }
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
new file mode 100644
index 0000000000..1e3b268423
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
+
+#include <unordered_map>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+namespace gpu {
+
+// This class encapsulates the bindings between HloInstructions and LLVM IR
+// values that represent their addresses.
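+//
+// A rough sketch of how the GPU IR emitters use this class:
+//
+//   HloToIrBindings bindings(module, &buffer_assignment, &temp_buffer_offsets,
+//                            &ir_builder, /*is_nested=*/false);
+//   bindings.EmitBasePointersForHlos(io_hlos, non_io_hlos);
+//   llvm_ir::IrArray out = bindings.GetIrArray(*hlo);  // while emitting `hlo`
+//   bindings.UnbindAllLocalIrValues();  // after each top-level HLO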
+class HloToIrBindings {
+ public:
+ HloToIrBindings(const HloModule& module,
+ const BufferAssignment* buffer_assignment,
+ const TempBufferOffsets* temp_buffer_offsets,
+ llvm::IRBuilder<>* ir_builder, bool is_nested)
+ : buffer_assignment_(buffer_assignment),
+ temp_buffer_offsets_(temp_buffer_offsets),
+ is_nested_(is_nested),
+ ir_builder_(ir_builder),
+ alias_analysis_(module, *buffer_assignment_,
+ &ir_builder_->getContext()) {}
+
+ void EmitBasePointersForHlos(
+ tensorflow::gtl::ArraySlice<const HloInstruction*> io_hlos,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> non_io_hlos);
+
+ // Rebinds the given HLO to the LLVM IR value that represents its address.
+ void BindHloToIrValue(const HloInstruction& hlo, llvm::Value* ir_value);
+
+ // Unbinds all IR values that are defined in an LLVM function, e.g., function
+ // arguments and stack variables. Global variables are kept in base_ptrs_.
+ //
+ // This method is called after emitting code for each top-level HLO. The local
+ // IR values are out of scope at that point and should not be used.
+ void UnbindAllLocalIrValues();
+
+ // Returns whether `hlo` is bound to an LLVM IR value.
+ bool BoundToIrValue(const HloInstruction& hlo) const {
+ return base_ptrs_.count(&hlo);
+ }
+
+ llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; }
+
+ // A helper method that returns the base pointer of the IrArray for `hlo`.
+ llvm::Value* GetBasePointer(const HloInstruction& hlo) const {
+ auto it = base_ptrs_.find(&hlo);
+ CHECK(it != base_ptrs_.end());
+ return it->second;
+ }
+
+ // Returns the IrArray for the output of the given instruction.
+ llvm_ir::IrArray GetIrArray(const HloInstruction& hlo);
+
+ private:
+ // Emits IR to resolve (possibly nested) GetTupleElement instructions.
+ llvm::Value* EmitGetTupleElement(const HloInstruction* gte,
+ llvm::Value* base_ptr);
+
+ // Returns `ir_value` cast to a pointer to the LLVM type that corresponds to
+ // the shape of `hlo`.
+ llvm::Value* GetTypedIrValue(const HloInstruction& hlo,
+ llvm::Value* ir_value);
+
+ const BufferAssignment* buffer_assignment_;
+
+ const TempBufferOffsets* temp_buffer_offsets_;
+
+ const bool is_nested_;
+
+ llvm::IRBuilder<>* ir_builder_;
+
+ // Maps each bound HloInstruction to the typed base pointer of its buffer.
+ std::unordered_map<const HloInstruction*, llvm::Value*> base_ptrs_;
+
+ // The address of the memory block that contains all temporary buffers.
+ llvm::Value* temp_buffer_base_;
+
+ llvm_ir::AliasAnalysis alias_analysis_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HLO_TO_IR_BINDINGS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
new file mode 100644
index 0000000000..91fd7ae77a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
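+// Returns true if `hlo` may participate in a fusion on the GPU backend, i.e.
+// the backend knows how to emit it inside a fused computation.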
+bool IsFusile(const HloInstruction& hlo) {
+ return (hlo.IsElementwise() && hlo.operand_count() > 0) ||
+ hlo.opcode() == HloOpcode::kBroadcast ||
+ hlo.opcode() == HloOpcode::kConcatenate ||
+ hlo.opcode() == HloOpcode::kDynamicSlice ||
+ hlo.opcode() == HloOpcode::kDynamicUpdateSlice ||
+ hlo.opcode() == HloOpcode::kFusion ||
+ hlo.opcode() == HloOpcode::kGetTupleElement ||
+ hlo.opcode() == HloOpcode::kPad ||
+ hlo.opcode() == HloOpcode::kReduce ||
+ hlo.opcode() == HloOpcode::kReduceWindow ||
+ hlo.opcode() == HloOpcode::kReshape ||
+ hlo.opcode() == HloOpcode::kSlice ||
+ hlo.opcode() == HloOpcode::kTranspose;
+}
+
+} // namespace
+
+bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
+ int64 operand_index) {
+ HloInstruction* producer = consumer->mutable_operand(operand_index);
+
+ // Do not fuse a reduction-to-vector into its consumers; such a reduction
+ // should either stay unfused or become the root of a kInput fusion.
+ if (IsReductionToVector(*producer)) {
+ return false;
+ }
+
+ // We can't fuse library calls, so if a user of such an op could become a
+ // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
+ // further rationale.
+ if (producer->CouldBeBitcast() &&
+ ImplementedAsLibraryCall(*producer->operand(0))) {
+ return false;
+ }
+
+ // We may need to know the original operand layouts to emit an input fusion;
+ // so far we merely use the layout of an operand of the fusion node, which
+ // means we may only fuse elementwise operations. This restriction should be
+ // lifted later if we need to fuse other operations, e.g. transpose, for
+ // performance.
+ if ((IsReductionToVector(*consumer) ||
+ (HloOpcode::kFusion == consumer->opcode() &&
+ HloInstruction::FusionKind::kInput == consumer->fusion_kind())) &&
+ !producer->IsElementwise()) {
+ return false;
+ }
+
+ return IsFusile(*producer) && IsFusile(*consumer) &&
+ InstructionFusion::ShouldFuse(consumer, operand_index);
+}
+
+HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) {
+ if (IsReductionToVector(*consumer)) {
+ return HloInstruction::FusionKind::kInput;
+ }
+ if (HloOpcode::kFusion == consumer->opcode()) {
+ return consumer->fusion_kind();
+ }
+ return InstructionFusion::ChooseKind(producer, consumer);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.h b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
new file mode 100644
index 0000000000..21f3b542a2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+
+namespace xla {
+namespace gpu {
+
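+// GPU-specific instruction fusion pass. It restricts the generic
+// InstructionFusion heuristics to operations the GPU backend can emit inside a
+// fusion, and selects kInput fusion for reductions-to-vector. A typical
+// invocation (sketch, mirroring the accompanying tests):
+//
+//   GpuInstructionFusion fusion(/*may_duplicate=*/true);
+//   bool changed = fusion.Run(module).ValueOrDie();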
+class GpuInstructionFusion : public InstructionFusion {
+ public:
+ explicit GpuInstructionFusion(bool may_duplicate)
+ : InstructionFusion(may_duplicate) {}
+
+ bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
+
+ HloInstruction::FusionKind ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INSTRUCTION_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
new file mode 100644
index 0000000000..c58af04bad
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -0,0 +1,126 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+using InstructionFusionTest = HloTestBase;
+
+TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(reshape2, computation->root_instruction());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input"));
+ auto filter = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter"));
+
+ Window conv_window;
+ WindowDimension* conv_window_row = conv_window.add_dimensions();
+ conv_window_row->set_size(1);
+ WindowDimension* conv_window_col = conv_window.add_dimensions();
+ conv_window_col->set_size(2);
+ conv_window_col->set_padding_high(1);
+
+ ConvolutionDimensionNumbers conv_dnums;
+ conv_dnums.set_batch_dimension(0);
+ conv_dnums.set_feature_dimension(1);
+ conv_dnums.add_spatial_dimensions(2);
+ conv_dnums.add_spatial_dimensions(3);
+ conv_dnums.set_kernel_output_feature_dimension(0);
+ conv_dnums.set_kernel_input_feature_dimension(1);
+ conv_dnums.add_kernel_spatial_dimensions(2);
+ conv_dnums.add_kernel_spatial_dimensions(3);
+
+ auto conv = builder.AddInstruction(
+ HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}),
+ input, filter, conv_window, conv_dnums));
+ auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0}));
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ module->AddEntryComputation(builder.Build());
+ EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, GetTupleElementFused) {
+ HloComputation::Builder builder(TestName());
+ Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ auto gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, param, 0));
+ auto gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape, param, 1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, gte0, gte1));
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(HloOpcode::kFusion, root->opcode());
+ HloInstruction* fused_root = root->fused_expression_root();
+ EXPECT_EQ(HloOpcode::kAdd, fused_root->opcode());
+ // Check that operands of 'fused_root' are GTE.
+ EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(0)->opcode());
+ EXPECT_EQ(HloOpcode::kGetTupleElement, fused_root->operand(1)->opcode());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
new file mode 100644
index 0000000000..0821fb01ab
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -0,0 +1,200 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+
+#include <algorithm>
+
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// Returns whether the given shape is a matrix with no padding.
+bool IsRank2WithNoPadding(const Shape& shape) {
+ return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
+}
+
+// In a gemm operation where output = lhs * rhs, check whether the given shapes
+// are valid for the operation.
+bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
+ const Shape& output_shape) {
+ // The inputs and the output must
+ // 1) be matrices with no padding and a non-zero number of elements,
+ // 2) have an allowed element type.
+ bool type_is_allowed = (output_shape.element_type() == F32 ||
+ output_shape.element_type() == F64);
+ return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
+ IsRank2WithNoPadding(rhs_shape) &&
+ IsRank2WithNoPadding(output_shape) &&
+ !ShapeUtil::HasZeroElements(lhs_shape) &&
+ !ShapeUtil::HasZeroElements(rhs_shape);
+}
+} // namespace
+
+bool ImplementedAsGemm(const HloInstruction& hlo) {
+ // For certain types of Dot, we can call pre-canned BLAS gemm.
+ if (hlo.opcode() == HloOpcode::kDot) {
+ const Shape& lhs_shape = hlo.operand(0)->shape();
+ const Shape& rhs_shape = hlo.operand(1)->shape();
+
+ // If gemm can accept the operand shapes, use it rather than a custom
+ // kernel.
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
+ // The size of the reduction dimension should match. The shape inference
+ // guarantees this invariant, so the check here is for programming
+ // errors.
+ CHECK_EQ(lhs_shape.dimensions(1), rhs_shape.dimensions(0));
+ return true;
+ }
+ }
+
+ if (hlo.opcode() == HloOpcode::kFusion &&
+ hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot &&
+ hlo.fused_expression_root()->opcode() == HloOpcode::kDot) {
+ return true;
+ }
+
+ return false;
+}
+
+bool ImplementedAsDnnConvolution(const HloInstruction& hlo) {
+ // Forward convolution.
+ if (hlo.opcode() == HloOpcode::kConvolution) {
+ const ConvolutionDimensionNumbers& dnums =
+ hlo.convolution_dimension_numbers();
+ // Only 2D convolutions are implemented.
+ // TODO(b/32873825): add support for 3D convolutions using CuDNN.
+ if (dnums.spatial_dimensions_size() != 2) {
+ return false;
+ }
+ // CuDNN does not accept zero-element arguments.
+ if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) ||
+ ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) {
+ return false;
+ }
+
+ return true;
+ }
+
+ // Backward convolution.
+ if (hlo.opcode() == HloOpcode::kFusion &&
+ (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter ||
+ hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) {
+ return true;
+ }
+
+ return false;
+}
+
+bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
+ return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo);
+}
+
+bool IsReductionToVector(const HloInstruction& reduce) {
+ if (HloOpcode::kReduce != reduce.opcode()) {
+ return false;
+ }
+ const HloInstruction* input = reduce.operand(0);
+ return ShapeUtil::Rank(input->shape()) > 1 &&
+ ShapeUtil::Rank(reduce.shape()) == 1;
+}
+
+// This emits a device-side call to
+// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
+// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
+llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+ tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ llvm::IRBuilder<>* builder) {
+ std::vector<llvm::Type*> argument_types;
+ for (auto argument : arguments) {
+ argument_types.push_back(argument->getType());
+ }
+ auto* arguments_type = llvm::StructType::create(argument_types);
+ llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ builder->CreateStore(
+ arguments[i],
+ builder->CreateGEP(arguments_ptr,
+ {builder->getInt64(0), builder->getInt32(i)}));
+ }
+ return builder->CreateCall(
+ builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
+ "vprintf",
+ llvm::FunctionType::get(builder->getInt32Ty(),
+ {builder->getInt8Ty()->getPointerTo(),
+ arguments_type->getPointerTo()},
+ /*isVarArg=*/false)),
+ {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
+ arguments_ptr});
+}
+
+llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset,
+ llvm::IRBuilder<>* builder) {
+ int bit_width = value->getType()->getPrimitiveSizeInBits();
+
+ // Special case for efficiency
+ if (value->getType()->isFloatTy() && bit_width == 32) {
+ return llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_shfl_down_f32,
+ {value, offset, builder->getInt32(kWarpSize - 1)}, {}, builder);
+ }
+
+ // We must split values wider than 32 bits as the "shfl" instruction operates
+ // on 32-bit values.
+ int num_segments = CeilOfRatio(bit_width, 32);
+ llvm::Value* x = builder->CreateBitCast(
+ builder->CreateZExt(
+ builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
+ builder->getIntNTy(32 * num_segments)),
+ llvm::VectorType::get(builder->getInt32Ty(), num_segments));
+ for (int i = 0; i < num_segments; ++i) {
+ x = builder->CreateInsertElement(
+ x,
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_shfl_down_i32,
+ {builder->CreateExtractElement(x, i),
+ offset, builder->getInt32(kWarpSize - 1)},
+ {}, builder),
+ i);
+ }
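+ // Reassemble the 32-bit segments into an integer of the original bit width
+ // and bitcast it back to the original type.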
+ return builder->CreateBitCast(
+ builder->CreateTrunc(
+ builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
+ builder->getIntNTy(bit_width)),
+ value->getType());
+}
+
+const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo) {
+ while (hlo->opcode() == HloOpcode::kGetTupleElement) {
+ hlo = hlo->operand(0);
+ }
+ return hlo;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
new file mode 100644
index 0000000000..4d3e9b10b2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -0,0 +1,72 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
+
+#include <utility>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+namespace gpu {
+
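+// Number of threads in a warp on NVIDIA GPUs; the warp-level helpers below
+// (e.g. EmitShuffleDown) operate within a single warp of this size.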
+const int64 kWarpSize = 32;
+
+// Precondition: "hlo" is an operand of a Dot instruction.
+//
+// Returns whether "hlo" is foldable to its user.
+bool IsOperandFoldableToDot(const HloInstruction& hlo);
+
+// Returns true if GpuCompiler can fold any operands of "dot" into "dot" for
+// better performance.
+bool CanFoldOperandsIntoDot(const HloInstruction& dot);
+
+// Returns true if `hlo` will be implemented as a call to BLAS gemm.
+bool ImplementedAsGemm(const HloInstruction& hlo);
+
+// Returns true if `hlo` will be implemented as a call to cuDNN convolution.
+bool ImplementedAsDnnConvolution(const HloInstruction& hlo);
+
+// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
+// or cuDNN convolution.
+bool ImplementedAsLibraryCall(const HloInstruction& hlo);
+
+bool IsReductionToVector(const HloInstruction& reduce);
+
+// Emits a call to "vprintf" with the given format and arguments.
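+// For example, EmitPrintf("block %d\n", {block_id}, &b) emits a device-side
+// printf of one 32-bit integer (sketch; `block_id` is an i32 llvm::Value* and
+// `b` is the current llvm::IRBuilder<>).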
+llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
+ tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
+ llvm::IRBuilder<>* builder);
+
+// Emits code to shuffle data between threads of a warp. This has the same
+// semantics as the PTX "shfl.down" instruction [0] but works for values of any
+// size. The last operand of the emitted "shfl" is `kWarpSize - 1`.
+//
+// [0]
+// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl
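+//
+// For example, a warp-wide sum over float values can be emitted as (sketch):
+//
+//   for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
+//     partial = builder->CreateFAdd(
+//         partial,
+//         EmitShuffleDown(partial, builder->getInt32(offset), builder));
+//   }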
+llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset,
+ llvm::IRBuilder<>* builder);
+
+// Resolves GetTupleElement instruction operands starting with 'hlo'.
+// Returns the first ancestor instruction which is not a GetTupleElement.
+const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
new file mode 100644
index 0000000000..ee1027b8cc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -0,0 +1,645 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/platform/logging.h"
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Constants.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+using llvm_ir::SetToFirstInsertPoint;
+
+namespace gpu {
+
+IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config,
+ IrEmitterContext* ir_emitter_context, bool is_nested)
+ : ir_emitter_context_(ir_emitter_context),
+ ir_builder_(ir_emitter_context->llvm_module()->getContext()),
+ bindings_(ir_emitter_context->hlo_module(),
+ &ir_emitter_context->buffer_assignment(),
+ &ir_emitter_context->temp_buffer_offsets(), &ir_builder_,
+ is_nested),
+ hlo_module_config_(hlo_module_config) {
+ llvm::FastMathFlags fast_math_flags;
+ llvm_ir::SetFastMathFlags(&fast_math_flags);
+ ir_builder_.setFastMathFlags(fast_math_flags);
+}
+
+Status IrEmitter::DefaultAction(HloInstruction* hlo) {
+ ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
+ for (const HloInstruction* operand : hlo->operands()) {
+ operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
+ };
+ }
+ return EmitTargetElementLoop(
+ *hlo, GpuElementalIrEmitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &ir_builder_, GetNestedComputer())
+ .MakeElementGenerator(hlo, operand_to_generator));
+}
+
+Status IrEmitter::HandleConstant(HloInstruction* constant,
+ const Literal& literal) {
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
+ llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
+ *ir_emitter_context_->llvm_module(), initializer->getType(),
+ /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer,
+ /*Name=*/"");
+ VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl
+ << " emitted_value: " << llvm_ir::DumpToString(*global_for_const)
+ << std::endl
+ << " its type: "
+ << llvm_ir::DumpToString(*global_for_const->getType());
+ bindings_.BindHloToIrValue(*constant, global_for_const);
+ return Status::OK();
+}
+
+Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
+ VLOG(2) << "HandleBitcast: " << bitcast->ToString();
+ const HloInstruction* operand = bitcast->operand(0);
+ // Bitcast is a no-op, but we still want to bind it to an llvm::Value
+ // sometimes, e.g., when its operand is a constant or a bitcast of a
+ // constant.
+ if (bindings_.BoundToIrValue(*operand)) {
+ bindings_.BindHloToIrValue(*bitcast, bindings_.GetBasePointer(*operand));
+ }
+ return Status::OK();
+}
+
+Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) {
+ CHECK(bindings_.BoundToIrValue(*operand));
+ bindings_.BindHloToIrValue(
+ *get_tuple_element,
+ llvm_ir::EmitGetTupleElement(
+ get_tuple_element->shape(), get_tuple_element->tuple_index(),
+ // TODO(b/26344050): tighten the alignment here
+ // based on the real element type.
+ /*alignment=*/1, GetBasePointer(*operand), &ir_builder_));
+ return Status::OK();
+}
+
+Status IrEmitter::HandleSort(HloInstruction* sort,
+ HloInstruction* operand_instruction) {
+ // TODO(b/26783907): Implement sort on GPU.
+ return Unimplemented("sort");
+}
+
+Status IrEmitter::HandleSend(HloInstruction* send) {
+ return Unimplemented("Send is not implemented on GPU");
+}
+
+Status IrEmitter::HandleRecv(HloInstruction* recv) {
+ return Unimplemented("Recv is not implemented on GPU");
+}
+
+Status IrEmitter::HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ std::vector<llvm::Value*> base_ptrs;
+ for (const HloInstruction* operand : operands) {
+ base_ptrs.push_back(GetBasePointer(*operand));
+ }
+ llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_);
+ return Status::OK();
+}
+
+Status IrEmitter::EmitCallToNestedComputation(
+ const HloComputation& nested_computation,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output) {
+ TF_RET_CHECK(nested_computation.num_parameters() > 0);
+ llvm::Function*& emitted_function =
+ computation_to_ir_function_[&nested_computation];
+ if (emitted_function == nullptr) {
+ IrEmitterNested ir_emitter_nested(hlo_module_config_, nested_computation,
+ ir_emitter_context_);
+ TF_RETURN_IF_ERROR(
+ nested_computation.root_instruction()->Accept(&ir_emitter_nested));
+ emitted_function = ir_emitter_nested.GetEmittedFunction();
+ }
+
+ std::vector<llvm::Value*> arguments(operands.begin(), operands.end());
+ arguments.push_back(output);
+ arguments.push_back(bindings_.GetTempBufferBase());
+ ir_builder_.CreateCall(emitted_function, arguments);
+
+ return Status::OK();
+}
+
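+// For illustration only, the call emitted by EmitCallToNestedComputation above
+// follows a simple convention: the nested IR function receives the operand
+// buffers, then the output buffer, then the base of the temp-buffer block.
+// Roughly (the names below are illustrative, not emitted identifiers):
+//
+//   call void @nested_computation(%operand0, %operand1, %output,
+//                                 %temp_buffer_base)
+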
+bool IrEmitter::MaybeEmitSpecialAtomicOperation(
+ const HloComputation& computation, llvm::Value* output_address,
+ llvm::Value* source_address) {
+ CHECK_EQ(2, computation.num_parameters());
+
+ if (computation.instruction_count() != 3) {
+ // We special-case only computations with one computing instruction for now.
+ // Such a computation has exactly three instructions, given that it has two
+ // parameters.
+ return false;
+ }
+
+ HloOpcode root_opcode = computation.root_instruction()->opcode();
+ PrimitiveType element_type =
+ computation.root_instruction()->shape().element_type();
+ llvm::Value* source = ir_builder_.CreateLoad(source_address, "source");
+ if (root_opcode == HloOpcode::kAdd) {
+ // NVPTX supports atomicAdd on F32 and integer types.
+ if (element_type == F32) {
+ // F32 + F32
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_atomic_load_add_f32,
+ {output_address, source},
+ {output_address->getType()}, &ir_builder_);
+ return true;
+ }
+ if (primitive_util::IsIntegralType(element_type)) {
+ // integral + integral
+ ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address,
+ source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
+ return true;
+ }
+ }
+
+ // NVPTX supports atomicMax and atomicMin only on integer types.
+ if (root_opcode == HloOpcode::kMaximum &&
+ primitive_util::IsIntegralType(element_type)) {
+ // max(integral, integral)
+ ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Max, output_address,
+ source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
+ return true;
+ }
+
+ if (root_opcode == HloOpcode::kMinimum &&
+ primitive_util::IsIntegralType(element_type)) {
+ // min(integral, integral)
+ ir_builder_.CreateAtomicRMW(llvm::AtomicRMWInst::Min, output_address,
+ source,
+ llvm::AtomicOrdering::SequentiallyConsistent);
+ return true;
+ }
+
+ return false;
+}
+
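+// As a rough sketch, the special cases recognized above correspond to the
+// following CUDA-level operations (types and names are illustrative only):
+//
+//   atomicAdd((float*)output_address, source);  // F32 add
+//   atomicMax((int*)output_address, source);    // integral max
+//
+// Computations that do not match one of these patterns fall through to the
+// compare-and-swap loop emitted below.
+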
+Status IrEmitter::EmitAtomicOperationForNestedComputation(
+ const HloComputation& computation, llvm::Value* output_address,
+ llvm::Value* source_address) {
+ if (computation.num_parameters() != 2) {
+ // TODO(b/30258929): We only accept binary computations so far.
+ return Unimplemented(
+ "We only support atomic functions with exactly two parameters, but "
+ "computation %s has %lld.",
+ computation.name().c_str(), computation.num_parameters());
+ }
+
+ if (MaybeEmitSpecialAtomicOperation(computation, output_address,
+ source_address)) {
+ return Status::OK();
+ }
+
+ // Other binary computations can be made atomic as follows (the labels are
+ // basic block names used in the IR-emitting code below).
+ //
+ // atomic_op_loop_preheader:
+ // ...
+ // source = *source_address;
+ // old_output = *output_address;
+ // do {
+ // atomic_op_loop_body_entry:
+ // new_output = computation(old_output, source);
+ // (old_output, success) =
+ // atomicCAS(output_address, old_output, new_output);
+ // } while (!success);
+ //
+ // atomic_op_loop_exit:
+ // ...
+ //
+ // TODO(jingyue): Consider encapsulating the logic of emitting control flow
+ // into something similar to llvm_ir::ForLoop.
+ //
+ // Emit preparation code to the preheader.
+ llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock();
+ llvm::Type* element_ir_type =
+ output_address->getType()->getPointerElementType();
+ // old_output = *output_address;
+ llvm::Value* old_output_location = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr, "old_output_location");
+ ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"),
+ old_output_location);
+ llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock(
+ ir_builder_.GetInsertPoint(), "atomic_op_loop_exit");
+
+ // Emit the body of the loop that repeatedly invokes atomicCAS.
+ llvm::BasicBlock* loop_body_bb =
+ llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body",
+ ir_builder_.GetInsertBlock()->getParent());
+ ir_builder_.SetInsertPoint(loop_body_bb);
+ // Change preheader's successor from loop_exit_bb to loop_body_bb.
+ loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb);
+ // new_output = computation(old_output, source);
+ llvm::Value* new_output_location = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr, "new_output_location");
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ computation, {old_output_location, source_address}, new_output_location));
+
+ // (old_output, success) = atomicCAS(output_address, old_output, new_output);
+ llvm::Type* element_int_ir_type =
+ ir_builder_.getIntNTy(element_ir_type->getScalarSizeInBits());
+ // cmpxchg accepts integers only, so we bitcast the operands (old_output and
+ // new_output) to integers of the same bit width, and bitcast the result
+ // back to the original element type.
+ llvm::Value* old_output =
+ ir_builder_.CreateLoad(old_output_location, "old_output");
+ llvm::Value* new_output =
+ ir_builder_.CreateLoad(new_output_location, "new_output");
+ llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg(
+ ir_builder_.CreateBitCast(output_address,
+ element_int_ir_type->getPointerTo()),
+ ir_builder_.CreateBitCast(old_output, element_int_ir_type),
+ ir_builder_.CreateBitCast(new_output, element_int_ir_type),
+ llvm::AtomicOrdering::SequentiallyConsistent,
+ llvm::AtomicOrdering::SequentiallyConsistent);
+ // cmpxchg returns a pair. The first element is the original value at
+ // output_address and the second element is whether the swap is successful.
+ ir_builder_.CreateStore(
+ ir_builder_.CreateBitCast(
+ ir_builder_.CreateExtractValue(ret_value, 0, "old_output"),
+ element_ir_type),
+ old_output_location);
+ ir_builder_.CreateCondBr(
+ ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb,
+ loop_body_bb);
+
+ // Restore the insertion point to the exit basic block so that the caller of
+ // this method can continue emitting code to the right place.
+ SetToFirstInsertPoint(loop_exit_bb, &ir_builder_);
+ return Status::OK();
+}
+
+Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ TF_RET_CHECK(pred->shape().element_type() == PRED);
+
+ if (ShapeUtil::IsTuple(select->shape())) {
+ llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred),
+ GetBasePointer(*on_true),
+ GetBasePointer(*on_false), &ir_builder_);
+ return Status::OK();
+ }
+
+ // We must call IrEmitter::DefaultAction here (not the virtual subclass
+ // method): the subclass's `HandleSelect` delegates to
+ // `IrEmitter::HandleSelect`, and the subclass's `DefaultAction` assumes that
+ // no handler has been called yet.
+ return IrEmitter::DefaultAction(select);
+}
+
+Status IrEmitter::HandleDot(HloInstruction* dot,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction) {
+ const llvm_ir::IrArray& target_array = GetIrArray(*dot);
+ const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction);
+ const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction);
+
+ const Shape& lhs_shape = lhs_instruction->shape();
+ const Shape& rhs_shape = rhs_instruction->shape();
+
+ if (ShapeUtil::IsScalar(lhs_shape) && ShapeUtil::IsScalar(rhs_shape)) {
+ // If the operands are scalar, don't emit any loops.
+ llvm::Value* lhs_value =
+ lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
+ llvm::Value* rhs_value =
+ rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_);
+ llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value);
+ target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_);
+ return Status::OK();
+ }
+
+ // "Scalar dot non-scalar" or "non-scalar dot scalar" is invalid. See
+ // the semantics of Dot in the XLA documentation for details.
+ TF_RET_CHECK(!ShapeUtil::IsScalar(lhs_shape) &&
+ !ShapeUtil::IsScalar(rhs_shape));
+
+ // Reduce along the last dimension of the LHS and the second-to-last dimension
+ // of the RHS. Vectors are a special case where the reduction dimension is 0
+ // for both LHS and RHS. This results in a vector dot product producing a
+ // scalar.
+ const int64 lhs_reduction_dimension =
+ ShapeUtil::GetDimensionNumber(lhs_shape, -1);
+ const int64 rhs_reduction_dimension =
+ ShapeUtil::Rank(rhs_shape) >= 2
+ ? ShapeUtil::GetDimensionNumber(rhs_shape, -2)
+ : 0;
+
+ // Verify that the reduction dimensions of the two operands are the same size.
+ TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
+ rhs_shape.dimensions(rhs_reduction_dimension));
+
+ // Create loop nests which loop through the LHS operand dimensions and the RHS
+ // operand dimensions. The reduction dimensions of the LHS and RHS are handled
+ // in a separate innermost loop which performs the sum of products.
+ llvm_ir::ForLoopNest loop_nest(&ir_builder_);
+ llvm_ir::IrArray::Index lhs_index = EmitOperandArrayLoopNest(
+ lhs_array, lhs_reduction_dimension, "lhs", &loop_nest);
+ llvm_ir::IrArray::Index rhs_index = EmitOperandArrayLoopNest(
+ rhs_array, rhs_reduction_dimension, "rhs", &loop_nest);
+
+ // Create the reduction loop which does the sum of products reduction.
+ std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/lhs_shape.dimensions(lhs_reduction_dimension),
+ /*suffix=*/"reduction");
+
+ // The final entry in the rhs and lhs indexes is the indvar of the reduction
+ // loop.
+ lhs_index[lhs_reduction_dimension] = reduction_loop->GetIndVarValue();
+ rhs_index[rhs_reduction_dimension] = reduction_loop->GetIndVarValue();
+
+ // For computing the sum of products we alloca a single location to store the
+ // dot product result as we accumulate it within the reduction loop. After the
+ // reduction loop we load the result and store into the output array.
+ llvm::Type* accum_type = target_array.GetElementLlvmType();
+ llvm::Value* accum_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ accum_type, // The pointee type of the alloca instruction.
+ "accum_address", // The name of the alloca instuction.
+ &ir_builder_);
+
+ // Initialize the accumulator in the preheader to zero.
+ new llvm::StoreInst(
+ llvm::ConstantFP::get(accum_type, 0.0), // The value stored.
+ accum_address, // The address.
+ reduction_loop->GetPreheaderBasicBlock()
+ ->getTerminator()); // The instruction this store is inserted before.
+
+ // Emit the body of the reduction loop:
+ // accum = *accum_address
+ // updated_accum = accum + lhs_element * rhs_element
+ // *accum_address = updated_accum
+ TF_RET_CHECK(!reduction_loop->GetBodyBasicBlock()->empty());
+ ir_builder_.SetInsertPoint(
+ &*reduction_loop->GetBodyBasicBlock()->getFirstInsertionPt());
+ llvm::Value* lhs_element =
+ lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_);
+ llvm::Value* rhs_element =
+ rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_);
+ llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element);
+ llvm::Value* accum = ir_builder_.CreateLoad(accum_address);
+ llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product);
+ ir_builder_.CreateStore(updated_accum, accum_address);
+
+ // After the reduction loop exits, store the accumulator into the target
+ // address. The index into the target address is the concatenation of the lhs
+ // and rhs indexes with the reduction dimensions removed. The terms from the
+ // lhs index are the lower dimensions in the index, so we add them first.
+ llvm_ir::IrArray::Index target_index;
+ for (int dimension = 0; dimension < lhs_index.size(); ++dimension) {
+ if (dimension != lhs_reduction_dimension) {
+ target_index.push_back(lhs_index[dimension]);
+ }
+ }
+ for (int dimension = 0; dimension < rhs_index.size(); ++dimension) {
+ if (dimension != rhs_reduction_dimension) {
+ target_index.push_back(rhs_index[dimension]);
+ }
+ }
+ SetToFirstInsertPoint(reduction_loop->GetExitBasicBlock(), &ir_builder_);
+ target_array.EmitWriteArrayElement(
+ target_index,
+ ir_builder_.CreateLoad(
+ accum_address), // The value written to the target array.
+ &ir_builder_);
+
+ // Set the IR builder insert point to the exit basic block of the outermost
+ // loop. This ensures later instructions are inserted after this loop nest.
+ ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+ return Status::OK();
+}
+
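+// Conceptually, for a 2D dot the code emitted above corresponds to the
+// following loop nest (an illustrative C-like sketch, not literal output):
+//
+//   for (i in 0 .. lhs.dim(0))
+//     for (j in 0 .. rhs.dim(1)) {
+//       accum = 0;
+//       for (r in 0 .. lhs.dim(1))
+//         accum += lhs[i][r] * rhs[r][j];
+//       target[i][j] = accum;
+//     }
+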
+Status IrEmitter::HandleConvolution(HloInstruction* convolution,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction,
+ const Window& window) {
+ if (ShapeUtil::HasZeroElements(convolution->shape())) {
+ // Emit no code for an empty output.
+ return Status::OK();
+ }
+ // TODO(b/31409998): Support convolution with dilation.
+ return Unimplemented(
+ "Hit a case for convolution that is not implemented on GPU.");
+}
+
+Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) {
+ // TODO(b/33011107): Support cross replica sum on GPU.
+ return Unimplemented(
+ "Cross replica sum not implemented on GPU. See b/33011107.");
+}
+
+Status IrEmitter::HandleParameter(HloInstruction* parameter) {
+ return Status::OK();
+}
+
+Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) {
+ return EmitTargetElementLoop(
+ *reduce,
+ [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ // Initialize an accumulator with init_value.
+ llvm::AllocaInst* accumulator_addr =
+ ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ reduce->shape().element_type(), &ir_builder_));
+ ir_builder_.CreateStore(
+ ir_builder_.CreateLoad(GetBasePointer(*init_value)),
+ accumulator_addr);
+
+ // The enclosing loops go over all the target elements. Now we have to
+ // compute the actual target element. For this, we build a new loop nest
+ // to iterate over all the reduction dimensions in the argument.
+ // AddLoopsForShapeOnDimensions will return an Index where induction
+ // Value*s are placed for each dimension in dimensions, and all the rest
+ // are nullptrs.
+ llvm_ir::ForLoopNest loops(&ir_builder_);
+ const llvm_ir::IrArray::Index reduced_dims_index =
+ loops.AddLoopsForShapeOnDimensions(arg->shape(), dimensions,
+ "reduction_dim");
+
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);
+
+ // Build a full index for the input argument, using reduced_dims_index
+ // as the base. In reduced_dims_index only the reduction dimensions are
+ // filled in. We fill in the rest of the dimensions with induction
+ // Value*s taken from 'index' which iterates over the target array.
+ // See the high-level description in the XLA documentation for details.
+ llvm_ir::IrArray::Index input_index = reduced_dims_index;
+ llvm_ir::IrArray::Index::const_iterator it = index.begin();
+
+ for (int64 i = 0; i < input_index.size(); ++i) {
+ if (input_index[i] == nullptr) {
+ input_index[i] = *it++;
+ }
+ }
+ CHECK(index.end() == it);
+
+ // Apply the reduction function to the loaded value.
+ llvm::Value* input_address =
+ GetIrArray(*arg).EmitArrayElementAddress(input_index, &ir_builder_);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *function, {accumulator_addr, input_address}, accumulator_addr));
+
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
+ return ir_builder_.CreateLoad(accumulator_addr);
+ });
+}
+
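+// Conceptually, the element generator built above for Reduce corresponds to
+// (illustrative sketch; `reducer` stands for the nested reduction computation):
+//
+//   for (each target index I) {
+//     accum = *init_value;
+//     for (each index R over the reduction dimensions of arg)
+//       accum = reducer(accum, arg[combine(I, R)]);
+//     target[I] = accum;
+//   }
+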
+Status IrEmitter::HandleFusion(HloInstruction* fusion) {
+ // kFusion for library calls should be handled by
+ // IrEmitterUnnested::HandleFusion.
+ CHECK(HloInstruction::FusionKind::kLoop == fusion->fusion_kind());
+
+ std::vector<llvm_ir::IrArray> parameter_arrays;
+ for (HloInstruction* operand : fusion->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand));
+ }
+ GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter));
+
+ return EmitTargetElementLoop(*fusion, fused_emitter.GetRootGenerator());
+}
+
+Status IrEmitter::HandleCall(
+ HloInstruction* call, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) {
+ std::vector<llvm::Value*> operand_addresses;
+ for (HloInstruction* operand : operands) {
+ operand_addresses.push_back(GetBasePointer(*operand));
+ }
+ return EmitCallToNestedComputation(*computation, operand_addresses,
+ GetBasePointer(*call));
+}
+
+Status IrEmitter::HandleCustomCall(
+ HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) {
+ return Unimplemented("custom-call");
+}
+
+Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
+ return Unimplemented("Infeed is not supported on GPU (b/30467474)");
+}
+
+Status IrEmitter::HandleRng(HloInstruction* random,
+ RandomDistribution /*distribution*/) {
+ ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator;
+ for (const HloInstruction* operand : random->operands()) {
+ operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*operand).EmitReadArrayElement(index, &ir_builder_);
+ };
+ }
+ // Emits a single-threaded loop because the loop body generated by the element
+ // generator for Rng can't be parallelized (b/32333178).
+ return llvm_ir::LoopEmitter(
+ GpuElementalIrEmitter(hlo_module_config_,
+ ir_emitter_context_->llvm_module(),
+ &ir_builder_, GetNestedComputer())
+ .MakeElementGenerator(random, operand_to_generator),
+ GetIrArray(*random), &ir_builder_)
+ .EmitLoop();
+}
+
+llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
+ const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
+ tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) {
+ // Prepare the dimension list we will use to emit the loop nest. Outermost
+ // loops are added first: loops are added in major-to-minor order, and the
+ // reduction dimension is skipped.
+ std::vector<int64> dimensions;
+ const Shape& shape = operand_array.GetShape();
+ for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) {
+ int64 dimension = shape.layout().minor_to_major(i);
+ if (dimension != reduction_dimension) {
+ dimensions.push_back(dimension);
+ }
+ }
+
+ // Create loop nest with one for-loop for each dimension of the
+ // output.
+ llvm_ir::IrArray::Index index =
+ loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
+ // Verify every dimension except the reduction dimension was set in the index.
+ for (int dimension = 0; dimension < index.size(); ++dimension) {
+ if (dimension == reduction_dimension) {
+ DCHECK_EQ(nullptr, index[dimension]);
+ } else {
+ DCHECK_NE(nullptr, index[dimension]);
+ }
+ }
+ return index;
+}
+
+StatusOr<llvm::Value*> IrEmitter::ComputeNestedElement(
+ const HloComputation& computation,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements) {
+ llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(
+ computation.root_instruction()->shape().element_type(), &ir_builder_),
+ "return_buffer", &ir_builder_);
+ std::vector<llvm::Value*> parameter_buffers;
+ for (llvm::Value* parameter_element : parameter_elements) {
+ parameter_buffers.push_back(llvm_ir::EmitAllocaAtFunctionEntry(
+ parameter_element->getType(), "parameter_buffer", &ir_builder_));
+ ir_builder_.CreateStore(parameter_element, parameter_buffers.back());
+ }
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(computation, parameter_buffers,
+ return_buffer));
+ return ir_builder_.CreateLoad(return_buffer);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
new file mode 100644
index 0000000000..0764c61ede
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -0,0 +1,405 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An XLA HLO graph may contain multiple computations. These computations
+// fall into two types: nested and unnested. We translate each nested
+// computation (e.g., the computation operand of a Map operator) to a device
+// function. For an unnested computation, which is composed of top-level
+// HloInstructions, we generate a separate CUDA kernel for each HloInstruction.
+//
+// This file declares classes that translate an XLA HLO graph to LLVM IR for
+// GPUs. IrEmitterNested emits LLVM IR for nested computations, and
+// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR
+// for each individual HloInstruction is largely the same between these two
+// classes. Therefore, we implement the common logic in the Handle* functions in
+// the superclass IrEmitter.
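+//
+// As a rough usage sketch (illustrative; the class declarations below are
+// authoritative), a nested computation is lowered along the lines of
+//
+//   IrEmitterNested nested_emitter(module_config, computation, &context);
+//   TF_CHECK_OK(computation.root_instruction()->Accept(&nested_emitter));
+//   llvm::Function* device_function = nested_emitter.GetEmittedFunction();
+//
+// while an unnested computation is lowered to a ThunkSequence via
+// IrEmitterUnnested and ConsumeThunkSequence().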
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+// This class is the top-level API for the XLA HLO --> LLVM IR compiler.
+// It implements the DfsHloVisitor interface and emits an LLVM IR program that
+// implements the input HLO graph.
+//
+// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in
+// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault`
+// calls `T::DefaultAction`.
+class IrEmitter : public DfsHloVisitorWithDefault {
+ public:
+ IrEmitter(const IrEmitter&) = delete;
+ IrEmitter& operator=(const IrEmitter&) = delete;
+
+ // The following methods implement the DfsHloVisitorWithDefault interface.
+ Status DefaultAction(HloInstruction* hlo) override;
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleInfeed(HloInstruction* infeed) override;
+ Status HandleSort(HloInstruction* sort, HloInstruction* operand) override;
+ Status HandleSend(HloInstruction* send) override;
+ Status HandleRecv(HloInstruction* recv) override;
+ Status HandleParameter(HloInstruction* parameter) override;
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleCall(HloInstruction* call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) override;
+ Status HandleCustomCall(HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) override;
+ Status HandleRng(HloInstruction* random,
+ RandomDistribution /*distribution*/) override;
+
+ Status FinishVisit(HloInstruction* root) override { return Status::OK(); }
+
+ protected:
+ // Constructs an IrEmitter with the given IrEmitter context.
+ // ir_emitter_context is owned by the caller and should outlive the IrEmitter
+ // object.
+ explicit IrEmitter(const HloModuleConfig& hlo_module_config,
+ IrEmitterContext* ir_emitter_context, bool is_nested);
+
+ // A convenient helper for calling HloToIrBindings::GetIrArray.
+ llvm_ir::IrArray GetIrArray(const HloInstruction& inst) {
+ return bindings_.GetIrArray(inst);
+ }
+ // A convenient helper for calling HloToIrBindings::GetBasePointer.
+ llvm::Value* GetBasePointer(const HloInstruction& inst) const {
+ return bindings_.GetBasePointer(inst);
+ }
+ // A convenient helper for calling BufferAssignment::GetAllocationIndex.
+ BufferAllocation::Index GetAllocationIndex(const HloInstruction& hlo) const {
+ return ir_emitter_context_->buffer_assignment()
+ .GetUniqueTopLevelAllocation(&hlo)
+ .ConsumeValueOrDie()
+ ->index();
+ }
+
+ // Emits a single-threaded or multi-threaded loop that computes every element
+ // in the result of the given HLO instruction. This produces a series of
+ // nested loops (e.g., one for each dimension of `hlo`'s shape). The body of
+ // the innermost loop is provided by the body_emitter function.
+ virtual Status EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& body_emitter) = 0;
+
+ // Emits a call in IR to the given nested computation with the given operands
+ // and output. If no IR function has been previously emitted for the
+ // computation, also emits such a function.
+ Status EmitCallToNestedComputation(
+ const HloComputation& nested_computation,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands, llvm::Value* output);
+
+ // Emits an atomic operation that implements `nested_computation` in the
+ // sequentially consistent memory model. `output_address` and `source_address`
+ // are the arguments of the nested computation. For example,
+ // atomicAdd(output_address, *source_address).
+ Status EmitAtomicOperationForNestedComputation(
+ const HloComputation& nested_computation, llvm::Value* output_address,
+ llvm::Value* source_address);
+
+ GpuElementalIrEmitter::NestedComputer GetNestedComputer() {
+ return std::bind(&IrEmitter::ComputeNestedElement, this,
+ std::placeholders::_1, std::placeholders::_2);
+ }
+
+ IrEmitterContext* ir_emitter_context_;
+
+ // The following fields track the IR emission state. According to LLVM memory
+ // management rules, their memory is owned by the module.
+ llvm::IRBuilder<> ir_builder_;
+
+ // Mapping from HLO to its underlying LLVM value.
+ HloToIrBindings bindings_;
+
+ // HLO configuration data used during code generation.
+ const HloModuleConfig& hlo_module_config_;
+
+ private:
+ // Emits a series of nested loops for iterating over an operand array in the
+ // dot operation. Loops are constructed in major-to-minor dimension layout
+ // order. No loop is emitted for the given reduction_dimension. The function
+ // returns an IrArray index for the given operand_array containing the indvars
+ // of the loops. All dimensions of the index are filled except for the
+ // reduction dimension. name_suffix is the string to append to the names of
+ // LLVM constructs (e.g., basic blocks) constructed by this method.
+ llvm_ir::IrArray::Index EmitOperandArrayLoopNest(
+ const llvm_ir::IrArray& operand_array, int64 reduction_dimension,
+ tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest);
+
+ // A helper method for EmitAtomicOperationForNestedComputation. Certain
+ // computations, such as floating-point addition and integer maximization, can
+ // be implemented with a single LLVM atomic instruction. If "computation" is
+ // of this kind, emits code to do that and returns true; otherwise, returns
+ // false.
+ bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation,
+ llvm::Value* output_address,
+ llvm::Value* source_address);
+
+ StatusOr<llvm::Value*> ComputeNestedElement(
+ const HloComputation& computation,
+ tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
+
+ // Emits an IR function that implements `nested_computation` on elements of
+ // type `element_ir_type`, for use when emitting an atomic operation for the
+ // nested computation.
+ StatusOr<llvm::Function*> EmitAtomicFunctionForNestedComputation(
+ const HloComputation& nested_computation, llvm::Type* element_ir_type);
+
+ // Map nested computations to emitted IR functions. This serves as a cache so
+ // that IrEmitter does not emit multiple functions for the same
+ // HloComputation.
+ std::map<const HloComputation*, llvm::Function*> computation_to_ir_function_;
+};
+
+// Emits LLVM IR for unnested computations. Each HloInstruction is translated to
+// a separate CUDA kernel. These kernels are inserted into the resultant module
+// sorted in reverse postorder of the XLA HLO graph.
+class IrEmitterUnnested : public IrEmitter {
+ public:
+ IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+ const HloComputation* hlo_computation,
+ bool has_hybrid_result,
+ IrEmitterContext* ir_emitter_context);
+ IrEmitterUnnested(const IrEmitterUnnested&) = delete;
+ IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
+
+ // Transfers ownership of thunk_sequence_ out.
+ std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
+ return std::move(thunk_sequence_);
+ }
+
+ Status DefaultAction(HloInstruction* hlo) override;
+
+ // IrEmitterUnnested handles the following instructions differently from
+ // IrEmitter.
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+ Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+ HloComputation* condition, HloComputation* body) override;
+ Status HandleRng(HloInstruction* random,
+ RandomDistribution distribution) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+
+ Status EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& body_emitter) override;
+
+ // Same as `EmitTargetElementLoop`, but emits into the given `thunk` rather
+ // than `LastThunk()`.
+ Status EmitTargetElementLoopInThunk(
+ const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
+ KernelThunk* thunk);
+
+ private:
+ // Builds the appropriate thunk for the instruction `hlo` and returns the
+ // owning pointer to it. The caller needs to make sure `hlo` outlives the
+ // lifetime of the returned Thunk object.
+ std::unique_ptr<Thunk> BuildThunk(const HloInstruction* hlo);
+
+ // Builds the prototype of the IR kernel for `inst` and adds it to the module.
+ llvm::Function* BuildKernelPrototype(
+ const HloInstruction& inst,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos);
+
+ // Emits the base pointers for `hlo` and its operands. `io_hlos` will store
+ // all input/output HLOs among `hlo` and its operands.
+ llvm::Function* EmitBasePointersForHloAndItsOperands(
+ const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos);
+
+ // EmitColumnReduction and EmitRowReduction emit code for column and row
+ // reduction of a matrix and/or 3D tensor. Row and column reductions have
+ // different memory access patterns, so for performance their implementations
+ // are significantly different.
+ //
+ // Emits code that reduces a matrix of shape [height x width] to a vector of
+ // [width]. Other parameters have the same meaning as those of
+ // `EmitReductionToVector`. Note that the input shape might not be
+ // [height x width], but it can be bitcast to [height x width] with "height"
+ // being the major dimension.
+ Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce,
+ const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ HloComputation* reducer);
+
+ // Emits code that reduces a 3D tensor of shape [depth x height x width] to a
+ // vector of shape [height]. Other parameters have the same meaning as those
+ // of `EmitReductionToVector`. Note that the input shape might not be
+ // [depth x height x width], but it can be bitcast to [depth x height x width]
+ // with "depth" being the most major dimension.
+ Status EmitRowReduction(int64 depth, int64 height, int64 width,
+ HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ HloComputation* reducer);
+
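+ // Conceptually (illustrative sketch), a column reduction computes
+ //   out[x] = reduce over y of in[y][x]
+ // while a row reduction computes
+ //   out[y] = reduce over (z, x) of in[z][y][x].
+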
+ // Figures out whether `reduce` is a row or column reduction, and which
+ // dimensions to reduce, and calls either `EmitRowReduction` or
+ // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the
+ // input array, which is the operand of the Reduce instruction if unfused or
+ // of the Fusion instruction if fused. `input_gen` and `init_value_gen`
+ // generate elements of the input and the initial value. Other parameters mean
+ // the same as for `HandleReduce`.
+ //
+ // Prerequisite: `IsReductionToVector(*reduce)`
+ Status EmitReductionToVector(
+ HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reducer);
+
+ // Emits code to initialize the buffer of `inst` in the given `thunk`.
+ Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk);
+
+ // Returns a KernelThunk that invokes the kernel emitted for `inst`. The
+ // caller needs to make sure `inst` outlives the lifetime of the returned
+ // Thunk object.
+ std::unique_ptr<Thunk> BuildKernelThunk(const HloInstruction* inst);
+
+ // Returns a ConvolutionThunk that calls DNN to implement `inst`.
+ std::unique_ptr<Thunk> BuildConvolutionThunk(const HloInstruction* inst);
+
+ // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
+ // to make sure `inst` outlives the lifetime of the returned Thunk object.
+ std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
+
+ // Returns a CopyThunk that calls host-to-device cuMemcpy to implement `inst`.
+ std::unique_ptr<Thunk> BuildCopyThunk(const HloInstruction* inst);
+
+ // Returns a WhileThunk that invokes thunk sequences for 'condition' and
+ // 'body' sub-computations of while instruction 'hlo'.
+ std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
+
+ // Returns a ForThunk which executes 'loop_limit' invocations of a thunk
+ // sequence from the 'body' sub-computation of the while instruction 'hlo'.
+ std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
+ const int64 loop_limit);
+
+ Status Postprocess(HloInstruction* hlo) override;
+
+ // Returns the last generated thunk.
+ Thunk* LastThunk() const { return thunk_sequence_->back().get(); }
+
+ // The thunk sequence this IrEmitter generates for the input computation.
+ std::unique_ptr<ThunkSequence> thunk_sequence_;
+
+ // The HloComputation that this IrEmitter emits code for.
+ const HloComputation* hlo_computation_;
+
+ // Whether this computation will produce a hybrid result, that is, whether the
+ // computation produces a ShapedBuffer.
+ bool has_hybrid_result_;
+};
+
+// Emits LLVM IR for a nested computation to the resultant function.
+class IrEmitterNested : public IrEmitter {
+ public:
+ // Constructs an LLVM IR emitter for a nested HLO computation. The emitter
+ // produces IR into a dedicated IR function for the computation. See
+ // IrEmitter::IrEmitter for the meanings of the arguments.
+ IrEmitterNested(const HloModuleConfig& hlo_module_config,
+ const HloComputation& nested_computation,
+ IrEmitterContext* ir_emitter_context);
+ IrEmitterNested(const IrEmitterNested&) = delete;
+ IrEmitterNested& operator=(const IrEmitterNested&) = delete;
+
+ // Overrides the default empty implementation. Binds the given instruction
+ // "parameter" with the parameter of the IR function.
+ Status HandleParameter(HloInstruction* parameter) override;
+
+ llvm::Function* GetEmittedFunction() const { return emitted_function_; }
+
+ Status EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& body_emitter) override;
+
+ private:
+ llvm::Function* EmitBasePointersForNestedComputation(
+ const HloComputation& nested_computation,
+ std::vector<const HloInstruction*>* io_hlos);
+
+ llvm::Function* emitted_function_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
new file mode 100644
index 0000000000..b204d9625c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
+
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// IrEmitterContext encapsulates common (mutable and immutable) data structures
+// used by both IrEmitterNested and IrEmitterUnnested, such as the buffer
+// assignment and the name uniquer.
+class IrEmitterContext {
+ public:
+ IrEmitterContext(const HloModule* hlo_module,
+ const BufferAssignment* buffer_assignment,
+ const TempBufferOffsets* temp_buffer_offsets,
+ const perftools::gputools::DeviceDescription* device_desc,
+ llvm::Module* llvm_module)
+ : hlo_module_(hlo_module),
+ buffer_assignment_(buffer_assignment),
+ temp_buffer_offsets_(temp_buffer_offsets),
+ device_desc_(device_desc),
+ llvm_module_(llvm_module) {}
+ // Disallow copy and assign.
+ IrEmitterContext(const IrEmitterContext&) = delete;
+ IrEmitterContext& operator=(const IrEmitterContext&) = delete;
+
+ // Simple accessors.
+ const HloModule& hlo_module() const { return *hlo_module_; }
+ const BufferAssignment& buffer_assignment() const {
+ return *buffer_assignment_;
+ }
+ const TempBufferOffsets& temp_buffer_offsets() const {
+ return *temp_buffer_offsets_;
+ }
+ const perftools::gputools::DeviceDescription& device_description() const {
+ return *device_desc_;
+ }
+ llvm::Module* llvm_module() { return llvm_module_; }
+ NameUniquer* name_uniquer() { return &name_uniquer_; }
+
+ private:
+ const HloModule* hlo_module_;
+ const BufferAssignment* buffer_assignment_;
+ const TempBufferOffsets* temp_buffer_offsets_;
+ const perftools::gputools::DeviceDescription* device_desc_;
+ llvm::Module* llvm_module_;
+ NameUniquer name_uniquer_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
new file mode 100644
index 0000000000..dc5e2d8f02
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace gpu {
+
+IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
+ const HloComputation& nested_computation,
+ IrEmitterContext* ir_emitter_context)
+ : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
+ std::vector<const HloInstruction*> io_hlos;
+ emitted_function_ =
+ EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
+}
+
+llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
+ const HloComputation& nested_computation,
+ std::vector<const HloInstruction*>* io_hlos) {
+ std::vector<llvm::Type*> argument_types;
+ std::vector<int64> argument_dereferenceable_bytes;
+ for (const HloInstruction* param :
+ nested_computation.parameter_instructions()) {
+ io_hlos->push_back(param);
+ const Shape& param_shape = param->shape();
+ argument_types.push_back(
+ llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo());
+ int64 param_size = llvm_ir::ByteSizeOf(
+ param_shape, ir_emitter_context_->llvm_module()->getDataLayout());
+ argument_dereferenceable_bytes.push_back(param_size);
+ }
+ {
+ const HloInstruction* root = nested_computation.root_instruction();
+ io_hlos->push_back(root);
+ const Shape& root_shape = root->shape();
+ argument_types.push_back(
+ llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo());
+ int64 root_size = llvm_ir::ByteSizeOf(
+ root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
+ argument_dereferenceable_bytes.push_back(root_size);
+ }
+ // The base pointer of the memory block for all pre-allocated temp buffers.
+ argument_types.push_back(ir_builder_.getInt8PtrTy());
+
+ llvm::FunctionType* function_type =
+ llvm::FunctionType::get(ir_builder_.getVoidTy(), argument_types, false);
+ llvm::Function* function = llvm::Function::Create(
+ function_type, // The function type.
+ llvm::GlobalValue::InternalLinkage, // The linkage type.
+ llvm_ir::AsStringRef(ir_emitter_context_->name_uniquer()->GetUniqueName(
+ llvm_ir::SanitizeIrName(
+ nested_computation.name()))), // The name of the function.
+ ir_emitter_context_->llvm_module()); // The parent LLVM module.
+ for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
+ ++arg_no) {
+ int64 arg_size = argument_dereferenceable_bytes[arg_no];
+ if (arg_size > 0) {
+ function->addDereferenceableAttr(arg_no + 1, arg_size);
+ }
+ }
+
+ llvm::BasicBlock* entry_bb =
+ llvm::BasicBlock::Create(function->getContext(), "entry", function);
+ // Emit a "return void" at entry_bb's end, and sets the insert point before
+ // that return instruction.
+ ir_builder_.SetInsertPoint(
+ llvm::ReturnInst::Create(function->getContext(), entry_bb));
+
+ std::vector<const HloInstruction*> non_io_hlos;
+ for (const auto& hlo : nested_computation.instructions()) {
+ if (hlo->opcode() != HloOpcode::kParameter &&
+ hlo.get() != nested_computation.root_instruction()) {
+ non_io_hlos.push_back(hlo.get());
+ }
+ }
+ bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
+ return function;
+}
+
+Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
+ return Status::OK();
+}
+
+Status IrEmitterNested::EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& element_generator) {
+ return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo), &ir_builder_)
+ .EmitLoop();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
new file mode 100644
index 0000000000..79a6443346
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -0,0 +1,1745 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/StringRef.h"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
+#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// If a dimension is smaller than this, untiled transposition may be more
+// efficient.
+const int64 kMinDimensionToTransposeTiled = 16;
+
+// Returns true if all paths from `hlo` to `root` contain only tuples. The
+// result of such an HloInstruction does not need to be materialized when the
+// computation can have a hybrid result.
+bool ReachRootViaOnlyTuples(const HloInstruction& hlo,
+ const HloInstruction& root) {
+ if (hlo.opcode() != HloOpcode::kTuple) {
+ return false;
+ }
+
+ if (&hlo == &root) {
+ return true;
+ }
+
+ for (HloInstruction* user : hlo.users()) {
+ if (!ReachRootViaOnlyTuples(*user, root)) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+// If `hlo` is a rank-2 Transpose, returns its operand; otherwise returns `hlo`.
+const HloInstruction* StripTranspose(const HloInstruction& hlo) {
+ if (hlo.IsRank2Transpose()) {
+ return hlo.operand(0);
+ }
+ return &hlo;
+}
+
+// Updates the launch dimensions in "thunk" and annotates the launch dimensions
+// of the corresponding IR kernel in "llvm_module".
+// Precondition: "thunk" must be a KernelThunk.
+void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
+ llvm::Module* llvm_module) {
+ CHECK(Thunk::Kind::kKernel == thunk->kind());
+ KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
+ kernel_thunk->SetLaunchDimensions(launch_dims);
+
+ // Add __launch_bounds__ to metadata. This limits registers per thread to
+ // avoid out-of-resources launching errors.
+ llvm::NamedMDNode* nvvm_annotations_node =
+ llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
+ llvm::Function* ir_kernel =
+ llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
+ llvm::LLVMContext& llvm_context = llvm_module->getContext();
+ llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
+ llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
+ launch_dims.threads_per_block());
+ nvvm_annotations_node->addOperand(llvm::MDNode::get(
+ llvm_context,
+ {llvm::ConstantAsMetadata::get(ir_kernel),
+ llvm::MDString::get(llvm_context, "maxntidx"),
+ llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
+}
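+
+// The annotation added above corresponds to NVVM metadata along the lines of
+// (illustrative, using typed-pointer IR syntax):
+//
+//   !nvvm.annotations = !{!0}
+//   !0 = !{void (i8*)* @some_kernel, !"maxntidx", i32 <threads_per_block>}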
+} // namespace
+
+IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
+ const HloComputation* hlo_computation,
+ bool has_hybrid_result,
+ IrEmitterContext* ir_emitter_context)
+ : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
+ hlo_computation_(hlo_computation),
+ has_hybrid_result_(has_hybrid_result) {
+ // Initialize thunk_sequence_ to an empty list of thunks.
+ thunk_sequence_.reset(new ThunkSequence());
+}
+
+Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
+ bindings_.UnbindAllLocalIrValues();
+ return DfsHloVisitor::Postprocess(hlo);
+}
+
+namespace {
+bool ImplementedAsMemcpy(const HloInstruction& hlo) {
+ // `hlo` needs to satisfy three conditions to be implemented as a
+ // host-to-device cuMemcpy.
+ //
+ // 1. `hlo` is a kCopy instruction.
+ // 2. `hlo`'s only operand is a kConstant instruction.
+ // 3. `hlo` and its operand have the same shape (thus the same layout too).
+ return hlo.opcode() == HloOpcode::kCopy &&
+ hlo.operand(0)->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape());
+}
+} // namespace
+
+llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
+ const HloInstruction& inst,
+ tensorflow::gtl::ArraySlice<const HloInstruction*> escaped_hlos) {
+ // Compute the kernel name. The opcode string may contain "-" which cannot be
+ // in a PTX function name, so sanitize the name before uniquifying it.
+ string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
+ llvm_ir::SanitizeIrName(inst.name()));
+
+ // Create the kernel and add it to the module.
+ llvm::Module* module = ir_emitter_context_->llvm_module();
+ llvm::LLVMContext& context = module->getContext();
+ int num_escaped_hlos = escaped_hlos.size();
+ llvm::FunctionType* kernel_type = llvm::FunctionType::get(
+ llvm::Type::getVoidTy(context), // The type of function result.
+ std::vector<llvm::Type*>(num_escaped_hlos + 1,
+ ir_builder_.getInt8PtrTy()),
+ false); // Not a variadic argument function.
+ llvm::Function* kernel =
+ llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
+ kernel_name.c_str(), module);
+
+ // Add dereferenceable information to each of the escaped HLO parameters.
+ for (size_t arg_no = 0; arg_no < escaped_hlos.size(); ++arg_no) {
+ const HloInstruction* escaped_hlo = escaped_hlos[arg_no];
+ const Shape& escaped_hlo_shape = escaped_hlo->shape();
+ int64 escaped_hlo_size = llvm_ir::ByteSizeOf(
+ escaped_hlo_shape, ir_emitter_context_->llvm_module()->getDataLayout());
+ kernel->addDereferenceableAttr(arg_no + 1, escaped_hlo_size);
+ }
+
+ // The last argument is a pointer to the temporary buffer memory block.
+ // We know that it doesn't alias any of the escaped arguments (the inputs +
+ // the result). We also know how many bytes can be dereferenced in it.
+ const llvm::Argument& temp_buffer = kernel->getArgumentList().back();
+ int64 temp_buffer_size =
+ ir_emitter_context_->temp_buffer_offsets().TotalSizeInBytes();
+ int64 temp_buffer_arg_no = temp_buffer.getArgNo();
+ if (temp_buffer_size > 0) {
+ kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, temp_buffer_size);
+ }
+ kernel->setDoesNotAlias(temp_buffer_arg_no + 1);
+
+ // Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
+ // treats it as a CUDA kernel.
+ llvm::NamedMDNode* nvvm_annotations_node =
+ module->getOrInsertNamedMetadata("nvvm.annotations");
+ nvvm_annotations_node->addOperand(llvm::MDNode::get(
+ context, {llvm::ConstantAsMetadata::get(kernel),
+ llvm::MDString::get(context, "kernel"),
+ llvm::ConstantAsMetadata::get(ir_builder_.getInt32(1))}));
+
+ // Update the insert point to the entry basic block.
+ llvm::BasicBlock* entry_bb =
+ llvm::BasicBlock::Create(context,
+ "entry", // The name of the basic block.
+ kernel); // The parent/owner of "entry_bb".
+ // Emit a "return void" at entry_bb's end, and sets the insert point before
+ // that return instruction.
+ ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
+
+ return kernel;
+}
+
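+// In IR terms the prototype built above is, roughly (illustrative only):
+//
+//   define void @<kernel_name>(i8* %escaped_hlo_0, ..., i8* %escaped_hlo_N,
+//                              i8* %temp_buffer)
+//
+// i.e. one i8* argument per escaped HLO plus a trailing pointer to the
+// temp-buffer block, registered as a CUDA kernel via nvvm.annotations.
+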
+Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
+ thunk_sequence_->emplace_back(BuildKernelThunk(hlo));
+ return IrEmitter::DefaultAction(hlo);
+}
+
+Status IrEmitterUnnested::HandleDot(HloInstruction* dot,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction) {
+ if (ImplementedAsGemm(*dot)) {
+ thunk_sequence_->emplace_back(BuildGemmThunk(dot));
+ return Status::OK();
+ }
+ thunk_sequence_->emplace_back(BuildKernelThunk(dot));
+ return IrEmitter::HandleDot(dot, lhs_instruction, rhs_instruction);
+}
+
+Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction,
+ const Window& window) {
+ if (ImplementedAsDnnConvolution(*convolution)) {
+ thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution));
+ return Status::OK();
+ }
+ thunk_sequence_->emplace_back(BuildKernelThunk(convolution));
+ return IrEmitter::HandleConvolution(convolution, lhs_instruction,
+ rhs_instruction, window);
+}
+
+Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
+ HloInstruction* root = fusion->fused_expression_root();
+ // HandleFusion specializes reduction from a multi-dimensional array to a 1D
+  // array. The specialized version requires an initializer thunk that
+ // initializes the output array to the initial value of the reduce.
+ if (HloInstruction::FusionKind::kInput == fusion->fusion_kind()) {
+ switch (root->opcode()) {
+ case HloOpcode::kReduce: {
+ VLOG(3) << "Emitting fused reduction to vector: " << fusion->ToString();
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ thunks.emplace_back(BuildKernelThunk(fusion));
+ TF_RETURN_IF_ERROR(EmitInitializer(
+ fusion, static_cast<KernelThunk*>(thunks.back().get())));
+ bindings_.UnbindAllLocalIrValues();
+ thunks.emplace_back(BuildKernelThunk(fusion));
+ thunk_sequence_->emplace_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), fusion));
+ std::vector<llvm_ir::IrArray> parameter_arrays;
+ for (HloInstruction* operand : fusion->operands()) {
+ parameter_arrays.push_back(GetIrArray(*operand));
+ }
+ GpuElementalIrEmitter elemental_emitter(
+ hlo_module_config_, ir_emitter_context_->llvm_module(),
+ &ir_builder_, GetNestedComputer());
+ FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+ TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
+
+ Shape input_shape = root->operand(0)->shape();
+        // EmitReductionToVector requires the input shape to have a layout, but
+        // fused instructions don't have one. So we determine its layout from
+        // the fusion's operands. The choice of layout affects only performance,
+        // not correctness.
+ auto choose_input_layout = [](
+ tensorflow::gtl::ArraySlice<const HloInstruction*> operands,
+ Shape* input_shape) {
+ // Prefer the layout of an operand whose shape is compatible with
+ // input_shape.
+ for (const HloInstruction* operand : operands) {
+ if (ShapeUtil::Compatible(*input_shape, operand->shape())) {
+ LayoutUtil::CopyLayoutBetweenShapes(operand->shape(),
+ input_shape);
+ return;
+ }
+ }
+ // If no operand has a compatible shape, prefer an operand that has
+ // the same rank at least.
+ for (const HloInstruction* operand : operands) {
+ if (ShapeUtil::Rank(*input_shape) ==
+ ShapeUtil::Rank(operand->shape())) {
+ // Do not use CopyLayoutBetweenShapes because input_shape and
+ // operand->shape() may be incompatible.
+ *input_shape->mutable_layout() = operand->shape().layout();
+ return;
+ }
+ }
+          // When all of the above fail, which is rare, set the default layout.
+ LayoutUtil::SetToDefaultLayout(input_shape);
+ };
+ choose_input_layout(fusion->operands(), &input_shape);
+
+ return EmitReductionToVector(
+ root, input_shape, fused_emitter.GetGenerator(root->operand(0)),
+ fused_emitter.GetGenerator(root->operand(1)), root->dimensions(),
+ root->to_apply());
+ break;
+ }
+ default:
+ LOG(FATAL) << "Bad opcode for input fusion: "
+ << fusion->fused_expression_root()->opcode();
+ }
+ }
+ if (ImplementedAsGemm(*fusion)) {
+ thunk_sequence_->emplace_back(BuildGemmThunk(fusion));
+ return Status::OK();
+ }
+ if (ImplementedAsDnnConvolution(*fusion)) {
+ thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion));
+ return Status::OK();
+ }
+ thunk_sequence_->emplace_back(BuildKernelThunk(fusion));
+ return IrEmitter::HandleFusion(fusion);
+}
+
+namespace {
+
+// Returns the indices of the first elements of all consecutive subarrays of the
+// given array. For example:
+// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
+std::vector<size_t> ConsecutiveSegments(tensorflow::gtl::ArraySlice<int64> xs) {
+ std::vector<size_t> is = {0};
+ for (size_t i = 1; i < xs.size(); ++i) {
+ if (1 != xs[i] - xs[i - 1]) {
+ is.push_back(i);
+ }
+ }
+ return is;
+}
+
+// Merges the sequences of dimensions of the given shape which start at the
+// given indices `segs`.
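+//
+// For illustration (hypothetical values): MergeDimensions({0, 3, 4}, s), where
+// s has dimensions {2, 3, 5, 7, 11, 13}, produces a shape with dimensions
+// {2*3*5, 7, 11*13} = {30, 7, 143}.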
+Shape MergeDimensions(tensorflow::gtl::ArraySlice<size_t> segs,
+ const Shape& shape) {
+ std::vector<int64> dimensions;
+ for (size_t i = 1; i <= segs.size(); ++i) {
+ dimensions.push_back(std::accumulate(
+ shape.dimensions().begin() + segs[i - 1],
+ shape.dimensions().begin() +
+ (segs.size() == i ? shape.dimensions().size() : segs[i]),
+ 1, std::multiplies<int64>()));
+ }
+ return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(),
+ dimensions);
+}
+
+// Returns whether a copy between the given shapes is a 0-2-1 transpose, and if
+// so, the normalized and rank-reduced shapes. The shapes must have the same
+// dimensions, so only their layouts are considered.
+//
+// This function recognizes higher-rank transposes which are elementwise
+// equivalent to a 0-2-1 transpose.
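+//
+// For illustration (hypothetical shapes): copying f32[2,3,4,5] from a row-major
+// layout (minor_to_major {3,2,1,0}) to a layout with minor_to_major {1,3,2,0}
+// yields perm = {0,3,1,2}, which reduces to the rank-3 shapes {2,3,20} and
+// {2,20,3}, i.e. a 0-2-1 transpose of each other.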
+std::tuple<bool, Shape, Shape> IsTranspose021(const Shape& a, const Shape& b) {
+ CHECK(ShapeUtil::Compatible(a, b));
+ std::vector<int64> perm(a.dimensions().size());
+ {
+ std::vector<int64> layout_a(a.layout().minor_to_major().rbegin(),
+ a.layout().minor_to_major().rend());
+ std::vector<int64> layout_b(b.layout().minor_to_major().rbegin(),
+ b.layout().minor_to_major().rend());
+ for (size_t i = 0; i < perm.size(); ++i) {
+ perm[i] = PositionInContainer(layout_b, layout_a[i]);
+ }
+ }
+ auto segs = ConsecutiveSegments(perm);
+ Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a);
+ Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b);
+ if (3 == segs.size() && 0 == perm[0]) {
+ Shape reduced_a = MergeDimensions(segs, norm_a);
+ Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ b.element_type(),
+ Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions())));
+ return std::make_tuple(true, reduced_a, reduced_b);
+ }
+ return std::make_tuple(false, ShapeUtil::MakeNil(), ShapeUtil::MakeNil());
+}
+
+// Returns whether the given shapes could be those of a 0-2-1 transpose.
+// As 0-2-1 is a self-inverse permutation, which shape is the input and which is
+// the output is arbitrary.
+bool AreShapesForTranspose021(const Shape& a, const Shape& b) {
+ return 3 == b.dimensions().size() &&
+ ShapeUtil::Compatible(
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a),
+ ShapeUtil::PermuteDimensions(
+ {0, 2, 1},
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b)));
+}
+
+// Emits a tiled 0-2-1 transpose, assuming both the input and the output are
+// laid out from major to minor. The x- and y-dimensions are tiled in square
+// tiles of edge length `tile_size`. Each thread block of `tile_size` threads
+// transposes one tile: each thread copies a row from the input to a shared
+// memory tile, then copies a column from the shared memory tile to the output.
+//
+// `tile_size` should usually be the same as the warp size.
+//
+// Returns the number of tiles, which is also the number of thread blocks
+// needed.
+//
+// TODO(b/33320379): Here each block transposes 1 tile. It may be more efficient
+// to launch fewer blocks so each transposes many tiles, and
+// in any case, the number of blocks we can launch is limited.
+//
+// This is the same algorithm in CUDA:
+// https://github.com/tensorflow/tensorflow/blob/6172351b81af76d0b819fea6bb478cbd4016d6c2/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc#L183
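+//
+// For illustration (hypothetical sizes): with tile_size = 32 and a reduced
+// input shape of [2, 100, 65], input_dims_in_tiles = [2, 4, 3], so the kernel
+// is launched as 24 thread blocks of 32 threads each; blocks covering the last
+// tile in y or x copy only a partial tile.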
+int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output,
+ const int64 tile_size, llvm::IRBuilder<>* builder) {
+ // Adds `addend` to the given `dim` of `index`.
+ auto offset_dim = [builder](llvm_ir::IrArray::Index index,
+ llvm::Value* addend, int64 dim) {
+ index[dim] = builder->CreateAdd(index[dim], addend);
+ return index;
+ };
+
+ CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape()));
+
+ Shape input_shape =
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape());
+ Shape output_shape =
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape());
+ input = input.CastToShape(input_shape, builder);
+ output = output.CastToShape(output_shape, builder);
+
+ llvm::Type* tile_type = llvm::ArrayType::get(
+ llvm::ArrayType::get(input.GetElementLlvmType(), tile_size),
+      // One extra here to avoid shared memory bank conflicts.
+ tile_size + 1);
+ auto* tile = new llvm::GlobalVariable(
+ *builder->GetInsertBlock()->getParent()->getParent(), tile_type,
+ /*isConstant=*/false, llvm::GlobalValue::PrivateLinkage,
+ llvm::UndefValue::get(tile_type), "tile", nullptr,
+ llvm::GlobalValue::NotThreadLocal,
+ /*AddressSpace=*/3 /* GPU shared memory */);
+
+ // let x = threadIdx.x
+ llvm::Value* x = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, builder);
+ llvm_ir::AddRangeMetadata(0, tile_size, static_cast<llvm::Instruction*>(x));
+ x = builder->CreateIntCast(x, builder->getInt64Ty(), /*isSigned=*/true,
+ "thread.id.x");
+
+  // `emit_cp_tile` emits code equivalent to the following pseudocode:
+ // if (tile_size == tile_width && tile_size == tile_height) {
+ // unroll for (y in 0..tile_size) {
+ // emit_cp_element(index + {0, y, 0}, y);
+ // }
+ // } else if (x < tile_width) {
+ // for (y in 0..tile_height) {
+ // emit_cp_element(index + {0, y, 0}, y);
+ // }
+ // }
+ //
+ // We use this to emit both the copy from input to tile and the copy from tile
+ // to output.
+ //
+ // `index` is the origin of the row or column in the input or output array.
+ //
+  // `emit_cp_element(index, y)` emits code to copy a single element between the
+  // tile and the input or output array, where `y` is the `y`-position in the
+  // tile (whether that corresponds to a row or a column depends on whether we
+  // are copying from the input or to the output), and `index` is the index into
+  // the input or output array.
+ auto emit_cp_tile = [builder, tile_size, x, &offset_dim](
+ std::function<void(const llvm_ir::IrArray::Index&, llvm::Value*)>
+ emit_cp_element,
+ llvm::Value* tile_width, llvm::Value* tile_height,
+ const llvm_ir::IrArray::Index& index, const string& loop_name) {
+ llvm_ir::LlvmIfData if_not_last_row = llvm_ir::EmitIfThenElse(
+ builder->CreateAnd(
+ builder->CreateICmpEQ(builder->getInt64(tile_size), tile_width),
+ builder->CreateICmpEQ(builder->getInt64(tile_size), tile_height)),
+ "not_last_row", builder);
+ builder->SetInsertPoint(if_not_last_row.true_block->getTerminator());
+ for (int64 i = 0; i < tile_size; ++i) {
+ emit_cp_element(offset_dim(index, builder->getInt64(i), /*dim=*/1),
+ builder->getInt64(i));
+ }
+ builder->SetInsertPoint(if_not_last_row.false_block->getTerminator());
+ llvm_ir::LlvmIfData if_in_tile = llvm_ir::EmitIfThenElse(
+ builder->CreateICmpULT(x, tile_width), "in_tile", builder);
+ builder->SetInsertPoint(if_in_tile.true_block->getTerminator());
+ auto loop = llvm_ir::ForLoop::EmitForLoop(loop_name, builder->getInt64(0),
+ tile_height, builder->getInt64(1),
+ builder);
+ llvm_ir::SetToFirstInsertPoint(loop->GetHeaderBasicBlock(), builder);
+ builder->SetInsertPoint(loop->GetBodyBasicBlock()->getTerminator());
+ emit_cp_element(offset_dim(index, loop->GetIndVarValue(), /*dim=*/1),
+ loop->GetIndVarValue());
+ builder->SetInsertPoint(if_not_last_row.after_block->getTerminator());
+ };
+
+ auto input_dims_in_tiles = input_shape.dimensions();
+ // Unpermuted dimensions are untiled.
+ for (int i = 1; i < 3; ++i) {
+ input_dims_in_tiles[i] =
+ CeilOfRatio<int64>(input_dims_in_tiles[i], tile_size);
+ }
+ int64 num_tiles =
+ std::accumulate(input_dims_in_tiles.begin(), input_dims_in_tiles.end(), 1,
+ std::multiplies<int64>());
+ const llvm_ir::IrArray::Index input_tile_index(
+ /*linear=*/builder->CreateIntCast(
+ llvm_ir::AddRangeMetadata(
+ 0, num_tiles,
+ static_cast<llvm::Instruction*>(llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {},
+ builder))),
+ builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"),
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)),
+ builder);
+ const llvm_ir::IrArray::Index input_tile_origin = ({
+ llvm_ir::IrArray::Index index = input_tile_index;
+ for (int i = 1; i < 3; ++i) {
+ index[i] = builder->CreateMul(index[i], builder->getInt64(tile_size),
+ "tile_origin." + std::to_string(i));
+ }
+ index;
+ });
+ const llvm_ir::IrArray::Index input_index =
+ offset_dim(input_tile_origin, x, /*dim=*/2);
+ std::vector<llvm::Value*> tile_dims(input_shape.dimensions().size());
+  // Only the last row or column may not have full size.
+ for (int i = 1; i < 3; ++i) {
+ tile_dims[i] = builder->CreateSelect(
+ builder->CreateICmpEQ(input_tile_index[i],
+ builder->getInt64(input_dims_in_tiles[i] - 1)),
+ builder->getInt64(input_shape.dimensions(i) -
+ (input_dims_in_tiles[i] - 1) * tile_size),
+ builder->getInt64(tile_size), "tile_size");
+ }
+
+ // Load data from input memory to shared memory tile.
+ emit_cp_tile(
+ // tile[y, x] = input_array[index]
+ [builder, tile, x, &input](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y) {
+ builder->CreateStore(
+ input.EmitReadArrayElement(index, builder, "input_element"),
+ builder->CreateGEP(tile, {builder->getInt64(0), y, x}));
+ },
+ tile_dims[2], tile_dims[1], input_index, "input");
+
+  // Wait for all threads to reach this point, lest we copy a value from the
+  // tile to the output before another thread has copied it from the input to
+  // the tile.
+ // This is `__syncthreads` in CUDA.
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, builder);
+
+ const llvm_ir::IrArray::Index output_tile_index(
+ Permute({0, 2, 1}, input_tile_index.multidim()));
+ const llvm_ir::IrArray::Index output_tile_origin(
+ Permute({0, 2, 1}, input_tile_origin.multidim()));
+ const llvm_ir::IrArray::Index output_index =
+ offset_dim(output_tile_origin, x, /*dim=*/2);
+
+ // Store data from shared memory tile to output memory.
+ emit_cp_tile(
+ // output_array[index] = tile[x, y]
+ [builder, tile, x, &output](const llvm_ir::IrArray::Index& index,
+ llvm::Value* y) {
+ output.EmitWriteArrayElement(
+ index,
+ builder->CreateLoad(
+ builder->CreateGEP(tile, {builder->getInt64(0), x, y}),
+ "output_element"),
+ builder);
+ },
+ tile_dims[1], tile_dims[2], output_index, "output");
+
+ return num_tiles;
+}
+
+} // namespace
+
+Status IrEmitterUnnested::HandleCopy(HloInstruction* copy,
+ HloInstruction* operand) {
+ if (ImplementedAsMemcpy(*copy)) {
+ thunk_sequence_->emplace_back(BuildCopyThunk(copy));
+ return Status::OK();
+ }
+ bool is_transpose_021;
+ Shape reduced_input_shape, reduced_output_shape;
+ std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) =
+ IsTranspose021(operand->shape(), copy->shape());
+ if (is_transpose_021 &&
+ reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled &&
+ reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) {
+ thunk_sequence_->emplace_back(BuildKernelThunk(copy));
+ VLOG(3) << "Emitting tiled 0-2-1 transposition";
+ constexpr int64 tile_size = 32;
+ int64 num_tiles = EmitTranspose021Tiled(
+ GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_),
+ GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_),
+ tile_size, &ir_builder_);
+ UpdateLaunchDimensions(LaunchDimensions(num_tiles, tile_size), LastThunk(),
+ ir_emitter_context_->llvm_module());
+ return Status::OK();
+ }
+
+ return IrEmitter::HandleCopy(copy, operand);
+}
+
+Status IrEmitterUnnested::EmitColumnReduction(
+ int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+ // Divide the input matrix into tiles of size Kx1. For example, when the
+ // input matrix is 4x4 and K=2, the tiled matrix looks like
+ //
+ // 0123
+ // 0123
+ // 4567
+ // 4567 // Numbers indicate tile IDs.
+ //
+ // Each tile is first partially reduced to a scalar by a thread, and then the
+ // scalar is accumulated to the output vector using atomic operations. We
+ // choose 16 as the tile size, which matches Eigen's ColumnReduceKernel.
+ constexpr int64 kTileSize = 16;
+ // If the height is not a multiple of the tile size, we pad the bottom of the
+ // input matrix.
+ const int64 height_in_tiles = CeilOfRatio(height, kTileSize);
+
+ // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
+ // linear_index < height_in_tiles * width;
+ // linear_index += blockDim.x * gridDim.x) {
+ // y_in_tiles = linear_index / width;
+ // x = linear_index % width;
+ //
+ // partial_result = init_value;
+ // if (height % kTileSize == 0 ||
+ // y_in_tiles * kTileSize + kTileSize <= height) {
+ // for (element_id_in_tile : range(kTileSize)) {
+ // y = y_in_tiles * kTileSize + element_id_in_tile;
+ // partial_result = Reducer(partial_result, input[y][x]);
+ // }
+ // } else {
+ // for (element_id_in_tile : range(kTileSize)) {
+ // y = y_in_tiles * kTileSize + element_id_in_tile;
+ // if (y < height) {
+ // partial_result = Reducer(partial_result, input[y][x]);
+ // }
+ // }
+ // }
+ // AtomicReducer(&output[x], partial_result);
+ // }
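+  //
+  // For illustration (hypothetical sizes): with height=100, width=50, and
+  // kTileSize=16, height_in_tiles = 7 and the parallel loop covers 7*50 = 350
+  // tiles. Tiles with y_in_tiles <= 5 are entirely in bounds; the tiles with
+  // y_in_tiles == 6 take the bounds-checked path and reduce only rows 96..99.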
+ auto loop_body_emitter =
+ [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ // Emit the loop body that reduces one tile.
+ llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
+ input_shape.element_type(), &ir_builder_);
+ llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
+ {
+ TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
+ init_value_gen(llvm_ir::IrArray::Index({})));
+ ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ }
+
+ // Emit an inner for-loop that partially reduces the elements in the given
+ // tile.
+ llvm::Value* y_in_tiles = tile_index[0];
+ llvm::Value* x = tile_index[1];
+
+ auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+ std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ ir_builder_.getInt64(0),
+ ir_builder_.getInt64(kTileSize),
+ ir_builder_.getInt64(1), &ir_builder_);
+
+ // Emit the body of the partial reduction loop.
+ llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
+ &ir_builder_);
+ llvm::Value* y = ir_builder_.CreateNSWAdd(
+ ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)),
+ tile_element_loop->GetIndVarValue());
+ // Unless we know the tile is entirely in bounds, we have to emit a
+ // y-in-bounds check before reading from the input.
+ if (!tile_in_bounds) {
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ ir_builder_.CreateICmpULT(y, ir_builder_.getInt64(height)),
+ "y_in_bounds", &ir_builder_);
+
+ // Emit code that reads the input element and accumulates it to
+ // the partial reduction result.
+ llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+ }
+ llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
+ {
+ // {y,x} is an index to input_matrix_shape [height,width]. We need to
+ // convert that to an index to input_shape (the shape of the operand of
+ // "reduce"). This conversion is composed of a transposition from
+ // input_shape to normalized_input_shape and a reshape from
+ // normalized_input_shape to input_matrix_shape.
+ const Shape normalized_input_shape =
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
+ const std::vector<int64> transpose_dimension_mapping(
+ input_shape.layout().minor_to_major().rbegin(),
+ input_shape.layout().minor_to_major().rend());
+
+ const Shape input_matrix_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ input_shape.element_type(), {height, width});
+ const llvm_ir::IrArray::Index input_matrix_index(
+ {y, x}, input_matrix_shape, &ir_builder_);
+ const llvm_ir::IrArray::Index input_index =
+ input_matrix_index
+ .SourceIndexOfReshape(input_matrix_shape,
+ normalized_input_shape, &ir_builder_)
+ .SourceIndexOfTranspose(normalized_input_shape, input_shape,
+ transpose_dimension_mapping,
+ &ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
+ input_gen(input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ }
+ return (EmitCallToNestedComputation(
+ *reducer, {partial_reduction_result_address, input_address},
+ partial_reduction_result_address));
+ };
+
+ // y_end = kTileSize + y_in_tiles * kTileSize, i.e., the y location that's
+ // immediately beyond the tile.
+ llvm::Value* y_end = ir_builder_.CreateNSWAdd(
+ ir_builder_.getInt64(kTileSize),
+ ir_builder_.CreateNSWMul(y_in_tiles, ir_builder_.getInt64(kTileSize)));
+ llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
+ ir_builder_.CreateICmpULE(y_end, ir_builder_.getInt64(height)),
+ ir_builder_.getInt1(height % kTileSize == 0));
+    // The tile is entirely in bounds if "height" is a multiple of kTileSize or
+    // y_end <= height.
+ llvm_ir::LlvmIfData if_tile_in_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
+ &ir_builder_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
+ &ir_builder_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
+
+ // After the if-then-else statement on tile_in_bounds, emit atomic
+ // operations to accumulate the partial reduction result to the output
+ // element.
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
+ &ir_builder_);
+ const HloInstruction* output =
+ reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+ llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
+ llvm_ir::IrArray::Index(x, output->shape(), &ir_builder_), &ir_builder_,
+ "output_element_address");
+ return EmitAtomicOperationForNestedComputation(
+ *reducer, output_address, partial_reduction_result_address);
+ };
+
+  // Emit a parallel loop that iterates through all input tiles.
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+ reduce->shape().element_type(), {height_in_tiles, width}, {1, 0});
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ tiled_input_shape, ir_emitter_context_->device_description());
+ CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
+ UpdateLaunchDimensions(
+ launch_dimensions,
+ static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
+ ir_emitter_context_->llvm_module());
+ return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
+ launch_dimensions, &ir_builder_)
+ .EmitLoop();
+}
+
+Status IrEmitterUnnested::EmitRowReduction(
+ int64 depth, int64 height, int64 width, HloInstruction* reduce,
+ const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) {
+ // A naive algorithm is:
+  // 1. Divide the input tensor into tiles of size 1x1xK.
+  // 2. Partially reduce each tile to a scalar using one thread.
+  // 3. Accumulate that scalar into the output vector using atomic operations.
+ //
+ // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
+ // linear_index < depth * height * width_in_tiles;
+ // linear_index += blockDim.x * gridDim.x) {
+ // int x_in_tiles = linear_index % width_in_tiles;
+ // int y = linear_index / width_in_tiles % height;
+ // int z = linear_index / (height * width_in_tiles);
+ // float partial_result = 0;
+ // for (element_id_in_tile : range(kTileSize)) {
+ // int x = x_in_tiles * kTileSize + element_id_in_tile;
+ // if (x < width)
+  //       partial_result = reducer(partial_result, input[z][y][x]);
+ // }
+ // AtomicReducer(&output[y], partial_result);
+ // }
+ //
+ // Three optimizations are performed.
+ //
+  // 1. To coalesce global memory accesses, dilate the tile with a factor of 32
+  // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead
+  // of making each tile consecutive, we make tile 0 cover columns
+  // [0,32,64,...,224], tile 1 cover columns [1,33,65,...,225], and so on. This
+  // ensures that threads in a warp access consecutive memory in one iteration
+  // (i.e. coalesced). In the above example, the warp that contains threads 0-31
+  // accesses columns 0-31 in the first iteration, columns 32-63 in the second
+  // iteration, and so on.
+ //
+  // 2. Partially accumulate the partially reduced results computed by threads
+  // in the same warp using shfl_down. Using shfl_down is faster than directly
+  // using atomic operations because shfl_down transfers the data between
+  // threads using shared memory, and threads in the same warp run in lock step
+  // (so no extra synchronization is needed). See
+  // https://devblogs.nvidia.com/parallelforall/faster-parallel-reductions-kepler/
+  // for details. The downside is that, to produce correct results when using
+  // shfl_down, we need to guarantee that threads in the same warp work on input
+  // elements with the same y, so the number of tiles in each row must be a
+  // multiple of 32.
+ //
+ // 3. Specialize the case that the entire tile is in bounds. When that is
+ // true, we don't need to emit "if(x<width)" inside the loop on
+ // element_id_in_tile, which makes the code more friendly to optimizations
+ // such as LICM.
+ //
+ // for (linear_index = threadIdx.x + blockIdx.x * blockDim.x;
+ // linear_index < depth * height * width_in_tiles;
+ // linear_index += blockDim.x * gridDim.x) {
+ // int x_in_tiles = linear_index % width_in_tiles;
+ // int y = linear_index / width_in_tiles % height;
+ // int z = linear_index / (height * width_in_tiles);
+ // int warp_id = x_in_tiles / warpSize;
+ // int lane_id = x_in_tiles % warpSize;
+ // float partial_result = 0;
+ // int x = warp_id * kTileSize * warpSize + lane_id;
+ // if (width % (kTileSize * warpSize) == 0 ||
+ // x + (kTileSize - 1) * warpSize < width) {
+ // // The entire tile is in bounds.
+ // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
+ // ++element_id_in_tile, x += warpSize) {
+ // partial_result = Reducer(partial_result, input[z][y][x]);
+ // }
+ // } else {
+ // // The tile is partially in bounds.
+ // for (int element_id_in_tile = 0; element_id_in_tile < kTileSize;
+ // ++element_id_in_tile, x += warpSize) {
+ // if (x < width)
+ // partial_result = Reducer(partial_result, input[z][y][x]);
+ // }
+ // }
+ // for (shuffle_distance = 16; shuffle_distance > 0; shuffle_distance /= 2)
+ // partial_result = Reducer(
+ // partial_result,
+ // __shfl_down(partial_result, shuffle_distance));
+ // if (lane_id == 0)
+ // AtomicReducer(&output[y], partial_result);
+ // }
+ //
+ // Choose 8 as the tile size, which matches Eigen's RowReduceKernel.
+ constexpr int64 kTileSize = 8;
+ // Round the width in tiles up to the nearest multiple of kWarpSize, so that
+ // the use of shfl_down is valid.
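+  //
+  // For illustration (hypothetical sizes): with width=1000, kTileSize=8, and
+  // kWarpSize=32, width_in_tiles = RoundUpToNearest(CeilOfRatio(1000, 8), 32)
+  // = RoundUpToNearest(125, 32) = 128, so each (z, y) row is processed by 128
+  // threads (4 warps), each reducing up to kTileSize elements spaced kWarpSize
+  // apart.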
+ const int64 width_in_tiles =
+ RoundUpToNearest(CeilOfRatio(width, kTileSize), kWarpSize);
+
+ auto loop_body_emitter =
+ [=](const llvm_ir::IrArray::Index& tile_index) -> Status {
+ // Emit the loop body that reduces one tile.
+ llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType(
+ input_shape.element_type(), &ir_builder_);
+ llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca(
+ element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result");
+ {
+ TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value,
+ init_value_gen(llvm_ir::IrArray::Index({})));
+ ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address);
+ }
+
+ // Emit an inner for-loop that partially reduces the elements in the given
+ // tile.
+ llvm::Value* z = tile_index[0];
+ llvm::Value* y = tile_index[1];
+ llvm::Value* x_tile = tile_index[2];
+ llvm::Value* warp_id = ir_builder_.CreateUDiv(
+ x_tile, ir_builder_.getInt64(kWarpSize), "warp_id");
+ llvm::Value* lane_id = ir_builder_.CreateURem(
+ x_tile, ir_builder_.getInt64(kWarpSize), "lane_id");
+
+ // The x-location of the last element in this tile.
+ // last_x = lane_id + warpSize * (kTileSize - 1 + warp_id * kTileSize);
+ llvm::Value* last_x = ir_builder_.CreateNSWAdd(
+ lane_id,
+ ir_builder_.CreateNSWMul(
+ ir_builder_.getInt64(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ ir_builder_.getInt64(kTileSize - 1),
+ ir_builder_.CreateNSWMul(warp_id,
+ ir_builder_.getInt64(kTileSize)))));
+
+ auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status {
+ std::unique_ptr<llvm_ir::ForLoop> tile_element_loop =
+ llvm_ir::ForLoop::EmitForLoop("element_id_in_tile",
+ ir_builder_.getInt64(0),
+ ir_builder_.getInt64(kTileSize),
+ ir_builder_.getInt64(1), &ir_builder_);
+
+ // Emit the body of the partial reduction loop.
+ llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(),
+ &ir_builder_);
+ // x = lane_id + warpSize * (element_id_in_tile + warp_id * kTileSize);
+ llvm::Value* x = ir_builder_.CreateNSWAdd(
+ lane_id,
+ ir_builder_.CreateNSWMul(
+ ir_builder_.getInt64(kWarpSize),
+ ir_builder_.CreateNSWAdd(
+ tile_element_loop->GetIndVarValue(),
+ ir_builder_.CreateNSWMul(warp_id,
+ ir_builder_.getInt64(kTileSize)))));
+
+      // Unless we know the tile is entirely in bounds, we have to emit an
+      // x-in-bounds check before reading from the input.
+ if (!tile_in_bounds) {
+ llvm_ir::LlvmIfData if_x_in_bounds_data = llvm_ir::EmitIfThenElse(
+ ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(width)),
+ "x_in_bounds", &ir_builder_);
+
+ // Points ir_builder_ to the then-block.
+ llvm_ir::SetToFirstInsertPoint(if_x_in_bounds_data.true_block,
+ &ir_builder_);
+ }
+
+ // Emit code that reads the input element and accumulates it to the
+ // partial reduction result.
+ llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type);
+ {
+ // {z,y,x} is an index to input_3d_tensor_shape [depth,height,width]. We
+ // need to convert that to an index to input_shape (the shape of the
+ // operand of "reduce"). This conversion is composed of a transposition
+ // from input_shape to normalized_input_shape and a reshape from
+ // normalized_input_shape to input_3d_tensor_shape.
+ const Shape normalized_input_shape =
+ ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape);
+ const std::vector<int64> transpose_dimension_mapping(
+ input_shape.layout().minor_to_major().rbegin(),
+ input_shape.layout().minor_to_major().rend());
+ const Shape input_3d_tensor_shape =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ input_shape.element_type(), {depth, height, width});
+ const llvm_ir::IrArray::Index input_3d_tensor_index(
+ {z, y, x}, input_3d_tensor_shape, &ir_builder_);
+ const llvm_ir::IrArray::Index input_index =
+ input_3d_tensor_index
+ .SourceIndexOfReshape(input_3d_tensor_shape,
+ normalized_input_shape, &ir_builder_)
+ .SourceIndexOfTranspose(normalized_input_shape, input_shape,
+ transpose_dimension_mapping,
+ &ir_builder_);
+ TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value,
+ input_gen(input_index));
+ ir_builder_.CreateStore(input_ir_value, input_address);
+ }
+ return EmitCallToNestedComputation(
+ *reducer, {partial_reduction_result_address, input_address},
+ partial_reduction_result_address);
+ };
+
+ llvm::Value* tile_in_bounds = ir_builder_.CreateOr(
+ ir_builder_.getInt1(width % (kTileSize * kWarpSize) == 0),
+ ir_builder_.CreateICmpULT(last_x, ir_builder_.getInt64(width)));
+ llvm_ir::LlvmIfData if_tile_in_bounds_data =
+ llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block,
+ &ir_builder_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true));
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block,
+ &ir_builder_);
+ TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false));
+
+ // After the if-then-else statement on tile_in_bounds, emit calls to
+ // shfl_down that accumulate the partial reduction results of all threads
+ // from the warp.
+ llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block,
+ &ir_builder_);
+ for (int shuffle_distance = 16; shuffle_distance >= 1;
+ shuffle_distance /= 2) {
+ llvm::Value* partial_reduction_result = ir_builder_.CreateLoad(
+ partial_reduction_result_address, "partial_reduction_result");
+ llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca(
+ element_ir_type, nullptr, "result_from_other_lane");
+ ir_builder_.CreateStore(
+ EmitShuffleDown(partial_reduction_result,
+ ir_builder_.getInt32(shuffle_distance), &ir_builder_),
+ result_from_other_lane);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *reducer, {partial_reduction_result_address, result_from_other_lane},
+ partial_reduction_result_address));
+ }
+
+ const HloInstruction* output =
+ reduce->IsFused() ? reduce->fusion_instruction() : reduce;
+
+ // Emit an atomic operation that accumulates the partial reduction result of
+ // lane 0 (which holds the partially accumulated result for its warp) to the
+ // output element.
+ llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
+ ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)),
+ "lane_id_is_zero", &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block,
+ &ir_builder_);
+ llvm::Value* output_address = GetIrArray(*output).EmitArrayElementAddress(
+ llvm_ir::IrArray::Index(y, output->shape(), &ir_builder_), &ir_builder_,
+ "output_element_address");
+ return EmitAtomicOperationForNestedComputation(
+ *reducer, output_address, partial_reduction_result_address);
+ };
+
+  // Emit a parallel loop that iterates through all input tiles.
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout(
+ reduce->shape().element_type(), {depth, height, width_in_tiles},
+ {2, 1, 0});
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ tiled_input_shape, ir_emitter_context_->device_description());
+ CHECK(LastThunk()->kind() == Thunk::Kind::kSequential);
+ UpdateLaunchDimensions(
+ launch_dimensions,
+ static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
+ ir_emitter_context_->llvm_module());
+ return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape,
+ launch_dimensions, &ir_builder_)
+ .EmitLoop();
+}
+
+// Figures out whether `reduce` is a row or column reduction, and which
+// dimensions to reduce, and calls either `EmitRowReduction` or
+// `EmitColumnReduction` as appropriate.
+// Prerequisite: the shape of `reduce` has rank 1 and, if `reduce` is fused, the
+// fused subgraph is pure elementwise.
+Status IrEmitterUnnested::EmitReductionToVector(
+ HloInstruction* reduce, const Shape& input_shape,
+ const llvm_ir::ElementGenerator& input_gen,
+ const llvm_ir::ElementGenerator& init_value_gen,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reducer) {
+ // This emission requires "reduce" to have an input layout. It is either set
+ // by LayoutAssignment (for a top-level kReduce) or by InstructionFusion (for
+ // a fused kReduce).
+ CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
+ "doesn't set the input layout of "
+ << reduce->ToString();
+
+ // Specialize multi-dimensional-array-to-vector reduction.
+ //
+ // TODO(b/33239522): we could use the same algorithm for general reduction
+ // as long as the input dimensions to keep are adjacent in the layout and
+ // have the same relative layout as their corresponding output dimensions.
+ // For example, reducing shape [2,3,4,5] with minor_to_major={2,0,1,3} to
+ // shape [2,4] with minor_to_major={1,0} can be implemented as a column
+ // reduction from shape [15,8] to shape [8].
+ int64 input_dim_to_keep = -1;
+ for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+ ++input_dim) {
+ if (std::find(dimensions_to_reduce.begin(), dimensions_to_reduce.end(),
+ input_dim) == dimensions_to_reduce.end()) {
+ input_dim_to_keep = input_dim;
+ break;
+ }
+ }
+ CHECK_NE(-1, input_dim_to_keep);
+
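+  // For illustration (hypothetical shapes): reducing an f32[16,128] input with
+  // minor_to_major {1,0} over dimension 0 keeps dimension 1, the most minor
+  // dimension, so we take the column-reduction path with height=16 and
+  // width=128. Reducing the same input over dimension 1 keeps dimension 0, so
+  // we take the row-reduction path with depth=1, height=16, and width=128.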
+ if (LayoutUtil::Minor(input_shape.layout(), 0) == input_dim_to_keep) {
+ // Column reduction. Treat the result of "input" as a matrix whose width
+ // is the most minor dimension and height the product of other dimensions,
+ // and treat "reduce" as a column reduction of the input matrix.
+ const int64 width = ShapeUtil::ElementsIn(reduce->shape());
+ // "width" can be zero, so don't do
+ // height = ShapeUtil::ElementsIn(input_shape) / width;
+ int64 height = 1;
+ for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+ ++input_dim) {
+ if (input_dim != input_dim_to_keep) {
+ height *= input_shape.dimensions(input_dim);
+ }
+ }
+ return EmitColumnReduction(height, width, reduce, input_shape, input_gen,
+ init_value_gen, reducer);
+ } else {
+    // Reduce the row dimension of a matrix or reduce dimensions 0 and 2 in a
+ // 3D tensor. The size of dimension 1 (the height) is the size of the
+ // dimension to keep, the size of dimension 0 (the depth) is the product
+ // of dimensions that are more major than the dimension to keep, and the
+ // size of dimension 2 (the width) is the product of more minor
+ // dimensions.
+ int64 depth = 1;
+ int64 width = 1;
+ for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+ ++input_dim) {
+ if (PositionInContainer(
+ AsInt64Slice(input_shape.layout().minor_to_major()), input_dim) >
+ PositionInContainer(
+ AsInt64Slice(input_shape.layout().minor_to_major()),
+ input_dim_to_keep)) {
+ depth *= input_shape.dimensions(input_dim);
+ } else if (PositionInContainer(
+ AsInt64Slice(input_shape.layout().minor_to_major()),
+ input_dim) <
+ PositionInContainer(
+ AsInt64Slice(input_shape.layout().minor_to_major()),
+ input_dim_to_keep)) {
+ width *= input_shape.dimensions(input_dim);
+ }
+ }
+ int64 height = input_shape.dimensions(input_dim_to_keep);
+ return EmitRowReduction(depth, height, width, reduce, input_shape,
+ input_gen, init_value_gen, reducer);
+ }
+}
+
+Status IrEmitterUnnested::HandleReduce(
+ HloInstruction* reduce, HloInstruction* input, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reducer) {
+ // HandleReduce specializes reduction from a multi-dimensional array to a 1D
+ // array. The specialized version requires an initializer thunk that
+ // initializes the output array to the initial value of the reduce.
+ if (IsReductionToVector(*reduce) &&
+      // The NVPTX backend can't do an atomic cmpxchg narrower than 32 bits.
+ 32 <= primitive_util::BitWidth(reduce->shape().element_type())) {
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ thunks.emplace_back(BuildKernelThunk(reduce));
+ TF_RETURN_IF_ERROR(EmitInitializer(
+ reduce, static_cast<KernelThunk*>(thunks.back().get())));
+ bindings_.UnbindAllLocalIrValues();
+ thunks.emplace_back(BuildKernelThunk(reduce));
+ thunk_sequence_->emplace_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), reduce));
+ return EmitReductionToVector(
+ reduce, input->shape(),
+ [this, input](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*input).EmitReadArrayElement(index, &ir_builder_);
+ },
+ [this, init_value](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*init_value)
+ .EmitReadArrayElement(index, &ir_builder_);
+ },
+ dimensions_to_reduce, reducer);
+ }
+
+ thunk_sequence_->emplace_back(BuildKernelThunk(reduce));
+ return IrEmitter::HandleReduce(reduce, input, init_value,
+ dimensions_to_reduce, reducer);
+}
+
+Status IrEmitterUnnested::HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ bool all_tuple_elements_have_buffer = std::all_of(
+ operands.begin(), operands.end(), [this](HloInstruction* tuple_element) {
+ return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation(
+ tuple_element);
+ });
+  // Tuples (especially output tuples) can have too many tuple elements,
+  // causing the emitted kernel to exceed the parameter space limit
+  // (b/31336476). As an optimization, if all tuple elements have a buffer, we
+ // collect their buffer addresses in a host array, and then copy that array
+ // to the tuple's buffer.
+ //
+ // Some tuple elements (e.g. const or bitcast of const) might not have a
+ // buffer -- their contents are stored in code. In that case, we fall back
+ // to emitting kernels which have access to their buffer addresses in code.
+ if (all_tuple_elements_have_buffer) {
+ std::vector<BufferAllocation::Index> tuple_element_buffers;
+ for (const HloInstruction* tuple_element : operands) {
+ tuple_element_buffers.push_back(GetAllocationIndex(*tuple_element));
+ }
+ thunk_sequence_->emplace_back(MakeUnique<TupleThunk>(
+ tuple_element_buffers, GetAllocationIndex(*tuple), tuple));
+ return Status::OK();
+ }
+  // If `tuple` is a nested tuple that can be disassembled from the result
+  // tuple, GpuExecutable will disassemble it and return it as part of the
+  // resultant ShapedBuffer, so no thunk is needed.
+ if (has_hybrid_result_ &&
+ ReachRootViaOnlyTuples(*tuple, *hlo_computation_->root_instruction())) {
+ return Status::OK();
+ }
+ thunk_sequence_->emplace_back(BuildKernelThunk(tuple));
+ return IrEmitter::HandleTuple(tuple, operands);
+}
+
+Status IrEmitterUnnested::HandleGetTupleElement(
+ HloInstruction* get_tuple_element, HloInstruction* operand) {
+ // GetTupleElement IR is emitted in the IR context of the user instruction,
+ // and so we do not build a kernel for GetTupleElement instructions.
+ return Status::OK();
+}
+
+Status IrEmitterUnnested::HandleSelectAndScatter(
+ HloInstruction* select_and_scatter) {
+ CHECK_EQ(select_and_scatter->operand_count(), 3);
+ const auto* operand = select_and_scatter->operand(0);
+ const auto* source = select_and_scatter->operand(1);
+ const Window& window = select_and_scatter->window();
+ PrimitiveType operand_element_type = operand->shape().element_type();
+ const int64 rank = ShapeUtil::Rank(operand->shape());
+ CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
+ CHECK_EQ(rank, window.dimensions_size());
+
+ {
+ std::vector<std::unique_ptr<Thunk>> thunks;
+ thunks.emplace_back(BuildKernelThunk(select_and_scatter));
+ TF_RETURN_IF_ERROR(EmitInitializer(
+ select_and_scatter, static_cast<KernelThunk*>(thunks.back().get())));
+ bindings_.UnbindAllLocalIrValues();
+ thunks.emplace_back(BuildKernelThunk(select_and_scatter));
+ thunk_sequence_->emplace_back(
+ MakeUnique<SequentialThunk>(std::move(thunks), select_and_scatter));
+ }
+
+ // TODO(b/31410564): Implement dilation rate for select-and-scatter.
+ if (window_util::HasDilation(window)) {
+ return Unimplemented(
+ "Dilation for select-and-scatter not implemented on GPU. "
+ "See b/31410564.");
+ }
+
+ // kSelectAndScatter is implemented as two kernel launches: the first launch
+ // initializes the output array to the given initial value,
+ // and the second accumulates the "source" matrix to the
+ // selected elements in the output array. The first launch is already
+ // implemented by the initializer thunk generated earlier, so this function
+ // only needs to take care of the select-and-scatter part.
+ //
+ // Pseudo code for select-and-scatter:
+ //
+ // for (coordinates S in the source): # This loop is parallel.
+ // initialized_flag = false
+ // for (coordinates W in the window):
+ // I = S * stride + W - pad_low
+ // if I within bounds of operand:
+ // if !(initialized_flag and select(selected_value, operand(I))):
+ // selected_value = operand(I)
+ // selected_index = I
+ // initialized_flag = true
+ // output(selected_index) = scatter(output(selected_index), source(S))
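+  //
+  // For illustration (hypothetical parameters): with stride 2, padding_low 1,
+  // source coordinate S=2, and window coordinate W=0, the visited operand
+  // coordinate is I = 2*2 + 0 - 1 = 3; window positions that map to I < 0 or
+  // I >= the operand bound are skipped by the in-bounds check below.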
+ auto loop_body_emitter =
+ [=](const llvm_ir::IrArray::Index& source_index) -> Status {
+    // Allocate space to keep the currently selected value, its index, and a
+    // boolean flag indicating whether the value is initialized. The
+    // initialized_flag is initially set to false.
+ llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
+ "selected_value_address", &ir_builder_);
+ llvm::Value* selected_index_address =
+ llvm_ir::EmitAllocaAtFunctionEntryWithCount(
+ ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
+ "selected_index_address", &ir_builder_);
+ llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
+ ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
+ ir_builder_.CreateStore(ir_builder_.getInt1(false),
+ initialized_flag_address);
+
+ // Create the inner loop to iterate over the window.
+ llvm_ir::ForLoopNest window_loops(&ir_builder_);
+ std::vector<int64> window_size;
+ for (const auto& dim : window.dimensions()) {
+ window_size.push_back(dim.size());
+ CHECK_GT(dim.size(), 0);
+ }
+ const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
+ ShapeUtil::MakeShape(operand_element_type, window_size), "window");
+ llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
+ &ir_builder_);
+
+    // Compute the operand index to visit and evaluate whether the operand index
+    // is within the operand's bounds. The unsigned comparison also implicitly
+    // checks that the operand index is >= 0.
+ llvm_ir::IrArray::Index operand_index(source_index.size());
+ llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* strided_index = ir_builder_.CreateNSWMul(
+ source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
+ operand_index[i] = ir_builder_.CreateNSWSub(
+ ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
+ ir_builder_.getInt64(window.dimensions(i).padding_low()));
+ llvm::Value* index_condition = ir_builder_.CreateICmpULT(
+ operand_index[i],
+ ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
+ in_bounds_condition =
+ ir_builder_.CreateAnd(in_bounds_condition, index_condition);
+ }
+ CHECK(in_bounds_condition != nullptr);
+
+ // Only need to do something if the operand index is within the bounds.
+ // First check if the initialized_flag is set.
+ llvm_ir::LlvmIfData if_in_bounds =
+ llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
+ llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
+ ir_builder_.CreateLoad(initialized_flag_address), "initialized",
+ &ir_builder_);
+
+    // If the initialized_flag is false, initialize the selected value and index
+    // with the operand element currently being visited.
+ llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
+ const auto save_operand_index = [&](
+ const llvm_ir::IrArray::Index& operand_index) {
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* selected_index_address_slot =
+ ir_builder_.CreateInBoundsGEP(selected_index_address,
+ {ir_builder_.getInt32(i)});
+ ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
+ }
+ };
+ llvm_ir::IrArray operand_array(GetIrArray(*operand));
+ llvm::Value* operand_data =
+ operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
+ ir_builder_.CreateStore(operand_data, selected_value_address);
+ save_operand_index(operand_index);
+ ir_builder_.CreateStore(ir_builder_.getInt1(true),
+ initialized_flag_address);
+
+    // If the initialized_flag is true, call the `select` function to
+    // potentially update the selected value and index with the operand element
+    // currently being visited.
+ llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
+ const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
+ llvm::Value* operand_address =
+ operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
+ llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
+ llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_),
+ "select_return_buffer", &ir_builder_);
+ TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
+ *select_and_scatter->select(),
+ {selected_value_address, operand_address}, select_return_buffer));
+ llvm::Value* result = ir_builder_.CreateLoad(select_return_buffer);
+
+    // If the 'select' function returns false, update the selected value and the
+    // index to the operand element currently being visited.
+ llvm::Value* cond = ir_builder_.CreateICmpNE(
+ result, llvm::ConstantInt::get(
+ llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
+ "boolean_predicate");
+ llvm_ir::LlvmIfData if_select_lhs =
+ llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
+ llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
+ ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
+ selected_value_address);
+ save_operand_index(operand_index);
+
+ // After iterating over the window elements, scatter the source element to
+ // the selected index of the output. The value we store at the output
+ // location is computed by calling the `scatter` function with the source
+ // value and the current output value.
+ llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
+ &ir_builder_);
+ llvm_ir::IrArray::Index selected_index;
+ for (int64 i = 0; i < rank; ++i) {
+ llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
+ selected_index_address, {ir_builder_.getInt32(i)});
+ selected_index.push_back(
+ ir_builder_.CreateLoad(selected_index_address_slot));
+ }
+ llvm::Value* source_value_address =
+ GetIrArray(*source).EmitArrayElementAddress(source_index, &ir_builder_);
+ llvm::Value* output_value_address =
+ GetIrArray(*select_and_scatter)
+ .EmitArrayElementAddress(selected_index, &ir_builder_);
+ return EmitAtomicOperationForNestedComputation(
+ *select_and_scatter->scatter(), output_value_address,
+ source_value_address);
+ };
+
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ source->shape(), ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(
+ launch_dimensions,
+ // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
+ // consisting of two thunks, an initializer KernelThunk that initializes
+ // the output and another KernelThunk that accumulates the scattered
+ // elements.
+ static_cast<SequentialThunk*>(LastThunk())->thunks().back().get(),
+ ir_emitter_context_->llvm_module());
+ return ParallelLoopEmitter(loop_body_emitter, source->shape(),
+ launch_dimensions, &ir_builder_)
+ .EmitLoop();
+}
+
+Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while,
+ HloInstruction* init,
+ HloComputation* condition,
+ HloComputation* body) {
+ TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
+ condition->root_instruction()->shape().element_type() == PRED)
+ << "While condition computation must return bool";
+ // Build ForThunk for conformant while loops, otherwise build WhileThunk.
+ auto result = CanTransformWhileToFor(xla_while);
+ if (result.ok()) {
+ auto tuple = result.ConsumeValueOrDie();
+ // loop_trip_count = (limit - start + increment - 1) / increment
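+    // For illustration (hypothetical bounds): start=0, limit=10, increment=3
+    // gives (10 - 0 + 3 - 1) / 3 = 4 iterations (i = 0, 3, 6, 9).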
+ const int64 loop_trip_count =
+ (std::get<1>(tuple) - std::get<0>(tuple) + std::get<2>(tuple) - 1) /
+ std::get<2>(tuple);
+ thunk_sequence_->emplace_back(BuildForThunk(xla_while, loop_trip_count));
+ VLOG(3) << "Built ForThunk for while: " << xla_while->name();
+ } else {
+ thunk_sequence_->emplace_back(BuildWhileThunk(xla_while));
+ VLOG(3) << "Built WhileThunk for while: " << xla_while->name()
+ << " while-to-for transform status: " << result.status();
+ }
+ return Status::OK();
+}
+
+Status IrEmitterUnnested::HandleRng(HloInstruction* random,
+ RandomDistribution distribution) {
+ thunk_sequence_->push_back(BuildKernelThunk(random));
+ return IrEmitter::HandleRng(random, distribution);
+}
+
+Status IrEmitterUnnested::HandleSelect(HloInstruction* select,
+ HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ thunk_sequence_->push_back(BuildKernelThunk(select));
+ return IrEmitter::HandleSelect(select, pred, on_true, on_false);
+}
+
+llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands(
+ const HloInstruction& hlo, std::vector<const HloInstruction*>* io_hlos) {
+ const BufferAssignment& buffer_assignment =
+ ir_emitter_context_->buffer_assignment();
+ // GetTupleElement instructions are implemented by emitting IR that indexes
+ // and loads the target tuple element pointer from its operand (possibly
+ // recursively). For this reason, GetTupleElement instructions are associated
+ // with their operand buffer in 'io_hlos' and 'non_io_hlos' below.
+ std::vector<const HloInstruction*> non_io_hlos;
+ for (const HloInstruction* operand : hlo.operands()) {
+ const HloInstruction* to_lookup = LatestNonGteAncestor(operand);
+ if (buffer_assignment.HasTopLevelAllocation(to_lookup) &&
+ buffer_assignment.GetUniqueTopLevelAllocation(to_lookup)
+ .ConsumeValueOrDie()
+ ->IsInputOrOutput()) {
+ io_hlos->push_back(operand);
+ } else {
+ non_io_hlos.push_back(operand);
+ }
+ }
+
+ CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode());
+ if (buffer_assignment.HasTopLevelAllocation(&hlo) &&
+ buffer_assignment.GetUniqueTopLevelAllocation(&hlo)
+ .ConsumeValueOrDie()
+ ->IsInputOrOutput()) {
+ io_hlos->push_back(&hlo);
+ } else {
+ non_io_hlos.push_back(&hlo);
+ }
+
+ llvm::Function* kernel = BuildKernelPrototype(hlo, *io_hlos);
+ // bindings_ is reused because the bindings of kConstant to their underlying
+ // llvm::Constant can be shared for all HLOs in this computation.
+ bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
+ return kernel;
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
+ const HloInstruction* inst) {
+ std::vector<const HloInstruction*> io_hlos;
+ llvm::Function* kernel =
+ EmitBasePointersForHloAndItsOperands(*inst, &io_hlos);
+
+ // Compute the input buffer indices.
+ std::vector<BufferAllocation::Index> io_buffers;
+ for (const HloInstruction* io_hlo : io_hlos) {
+ io_buffers.push_back(GetAllocationIndex(*LatestNonGteAncestor(io_hlo)));
+ }
+
+ // Create a KernelThunk that launches the kernel that implements "inst".
+ return MakeUnique<KernelThunk>(io_buffers,
+ llvm_ir::AsString(kernel->getName()), inst);
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildCopyThunk(
+ const HloInstruction* inst) {
+ const HloInstruction* operand = inst->operand(0);
+ CHECK_EQ(HloOpcode::kConstant, operand->opcode());
+ return MakeUnique<CopyThunk>(
+ /*source_address=*/LiteralUtil::InternalData(operand->literal()),
+ /*destination_buffer=*/GetAllocationIndex(*inst),
+ /*mem_size=*/llvm_ir::ByteSizeOf(
+ operand->shape(),
+ ir_emitter_context_->llvm_module()->getDataLayout()),
+ inst);
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
+ const HloInstruction* inst) {
+ if (inst->opcode() == HloOpcode::kDot) {
+ const HloInstruction* lhs = inst->operand(0);
+ const HloInstruction* rhs = inst->operand(1);
+ return MakeUnique<GemmThunk>(
+ GetAllocationIndex(*lhs), // The buffer assigned to LHS.
+ GetAllocationIndex(*rhs), // The buffer assigned to RHS.
+ GetAllocationIndex(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ false, // Do not transpose LHS.
+ false, // Do not transpose RHS.
+ inst);
+ }
+
+ if (inst->opcode() == HloOpcode::kFusion) {
+ const HloInstruction* dot = inst->fused_expression_root();
+ DCHECK(dot->opcode() == HloOpcode::kDot);
+ const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
+ const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
+ DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
+ rhs_parameter->opcode() == HloOpcode::kParameter);
+ const HloInstruction* lhs =
+ inst->operand(lhs_parameter->parameter_number());
+ const HloInstruction* rhs =
+ inst->operand(rhs_parameter->parameter_number());
+
+ return MakeUnique<GemmThunk>(
+ GetAllocationIndex(*lhs), // The buffer assigned to LHS.
+ GetAllocationIndex(*rhs), // The buffer assigned to RHS.
+ GetAllocationIndex(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ dot->operand(0)->IsRank2Transpose(), // Transpose LHS.
+        dot->operand(1)->IsRank2Transpose(),  // Transpose RHS.
+ inst);
+ }
+
+ LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString();
+}
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildConvolutionThunk(
+ const HloInstruction* inst) {
+ const HloInstruction* lhs = inst->operand(0);
+ const HloInstruction* rhs = inst->operand(1);
+ if (inst->opcode() == HloOpcode::kConvolution) {
+ // Forward convolution.
+ return MakeUnique<ConvolutionThunk>(
+ ConvolutionThunk::ConvolutionKind::kForward,
+ /*input_buffer=*/GetAllocationIndex(*lhs),
+ /*filter_buffer=*/GetAllocationIndex(*rhs),
+ /*output_buffer=*/GetAllocationIndex(*inst),
+ /*input_shape=*/lhs->shape(),
+ /*filter_shape=*/rhs->shape(),
+ /*output_shape=*/inst->shape(), inst->window(),
+ inst->convolution_dimension_numbers(), inst);
+ }
+
+ // Backward filter convolution, which takes the input (activations) and the
+ // gradients, and computes the filter.
+ CHECK_EQ(HloOpcode::kFusion, inst->opcode());
+ switch (inst->fusion_kind()) {
+ case HloInstruction::FusionKind::kConvBackwardFilter:
+ return MakeUnique<ConvolutionThunk>(
+ ConvolutionThunk::ConvolutionKind::kBackwardFilter,
+ /*input_buffer=*/GetAllocationIndex(*lhs),
+ /*filter_buffer=*/GetAllocationIndex(*inst),
+ /*output_buffer=*/GetAllocationIndex(*rhs),
+ /*input_shape=*/lhs->shape(),
+ /*filter_shape=*/inst->shape(),
+ /*output_shape=*/rhs->shape(), inst->window(),
+ inst->convolution_dimension_numbers(), inst);
+ case HloInstruction::FusionKind::kConvBackwardInput:
+ return MakeUnique<ConvolutionThunk>(
+ ConvolutionThunk::ConvolutionKind::kBackwardInput,
+ /*input_buffer=*/GetAllocationIndex(*inst),
+ /*filter_buffer=*/GetAllocationIndex(*rhs),
+ /*output_buffer=*/GetAllocationIndex(*lhs),
+ /*input_shape=*/inst->shape(),
+ /*filter_shape=*/rhs->shape(),
+ /*output_shape=*/lhs->shape(), inst->window(),
+ inst->convolution_dimension_numbers(), inst);
+ default:
+ LOG(FATAL) << "Not a convolution-fusion";
+ }
+}
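+
+// Note the buffer-role permutation above: for a backward-filter fusion the
+// fusion result is the filter, and for a backward-input fusion the fusion
+// result is the input, so the instruction's own allocation is passed as the
+// filter or input buffer respectively and the shapes are permuted to match.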
+
+Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo,
+ KernelThunk* thunk) {
+ bool fused = HloOpcode::kFusion == hlo->opcode();
+
+ const HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
+ CHECK(inst->opcode() == HloOpcode::kSelectAndScatter ||
+ inst->opcode() == HloOpcode::kReduce);
+ const HloInstruction* init_value = nullptr;
+ switch (inst->opcode()) {
+ case HloOpcode::kSelectAndScatter:
+ init_value = inst->operand(2);
+ break;
+ case HloOpcode::kReduce:
+ init_value = inst->operand(1);
+ break;
+ default:
+ LOG(FATAL) << "Opcode " << inst->opcode()
+ << " should not need an initializer.";
+ }
+
+ if (fused && init_value->opcode() == HloOpcode::kParameter) {
+ init_value = hlo->operand(init_value->parameter_number());
+ }
+
+ return EmitTargetElementLoopInThunk(
+ *hlo,
+ [=](const llvm_ir::IrArray::Index& index) {
+ return GetIrArray(*init_value)
+ .EmitReadArrayElement(index, &ir_builder_);
+ },
+ thunk);
+}
+
+namespace {
+
+// Checks that all buffers used during while loop iteration share the same
+// buffer allocation. This includes buffers for while result, while init
+// operand, condition parameter, body parameter and body result.
+// Returns OK on success, error status otherwise.
+Status CheckWhileBuffersShareAllocation(
+ const HloInstruction* xla_while,
+ const BufferAssignment& buffer_assignment) {
+ return ShapeUtil::ForEachSubshape(
+ xla_while->shape(),
+ [&buffer_assignment, &xla_while](const Shape& /*subshape*/,
+ const ShapeIndex& index) -> Status {
+ auto check = [&buffer_assignment](const HloInstruction* a,
+ const HloInstruction* b,
+ const ShapeIndex& index) -> Status {
+ BufferAllocation::Index index_a =
+ buffer_assignment.GetUniqueAllocation(a, index)
+ .ConsumeValueOrDie()
+ ->index();
+ BufferAllocation::Index index_b =
+ buffer_assignment.GetUniqueAllocation(b, index)
+ .ConsumeValueOrDie()
+ ->index();
+ if (index_a != index_b) {
+ return InternalError(
+ "instruction %s does not share allocation with "
+ "instruction %s ",
+ a->ToString().c_str(), b->ToString().c_str());
+ }
+ return Status::OK();
+ };
+ const HloInstruction* condition_parameter =
+ xla_while->while_condition()->parameter_instruction(0);
+ const HloComputation* body = xla_while->while_body();
+ const HloInstruction* body_parameter = body->parameter_instruction(0);
+ const HloInstruction* body_result = body->root_instruction();
+ TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index));
+ TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index));
+ TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index));
+ TF_RETURN_IF_ERROR(check(xla_while, body_result, index));
+ return Status::OK();
+ });
+}
+
+} // namespace
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
+ const HloInstruction* hlo) {
+ // Check that all while-related buffers share an allocation.
+ TF_CHECK_OK(CheckWhileBuffersShareAllocation(
+ hlo, ir_emitter_context_->buffer_assignment()));
+
+ // Generate thunk sequence for while 'condition'.
+ HloComputation* condition = hlo->while_condition();
+ IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
+ /*has_hybrid_result=*/false,
+ ir_emitter_context_);
+ TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition));
+
+ // Generate thunk sequence for while 'body'.
+ HloComputation* body = hlo->while_body();
+ IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
+ false /* has_hybrid_result */,
+ ir_emitter_context_);
+ TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));
+
+ return MakeUnique<WhileThunk>(
+ GetAllocationIndex(*condition->root_instruction()), // cond result
+ ir_emitter_condition.ConsumeThunkSequence(),
+ ir_emitter_body.ConsumeThunkSequence(), hlo);
+}
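+
+// The resulting WhileThunk presumably alternates between the two sequences at
+// run time: it executes the condition thunks, reads the predicate from the
+// condition root's allocation, and executes the body thunks while the
+// predicate is true.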
+
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
+ const HloInstruction* hlo, const int64 loop_limit) {
+ // Check that all while-related buffers share an allocation.
+ TF_CHECK_OK(CheckWhileBuffersShareAllocation(
+ hlo, ir_emitter_context_->buffer_assignment()));
+
+ // Generate thunk sequence for while 'body' (will be used as a For loop body).
+ HloComputation* body = hlo->while_body();
+ IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
+ false /* has_hybrid_result */,
+ ir_emitter_context_);
+ TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body));
+
+ return MakeUnique<ForThunk>(loop_limit,
+ ir_emitter_body.ConsumeThunkSequence(), hlo);
+}
+
+Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ hlo.shape(), ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(launch_dimensions, thunk,
+ ir_emitter_context_->llvm_module());
+ // Emit a parallel loop that computes the partition of the output that each
+ // thread is in charge of.
+ return ParallelLoopEmitter(element_generator, GetIrArray(hlo),
+ launch_dimensions, &ir_builder_)
+ .EmitLoop();
+}
+
+Status IrEmitterUnnested::EmitTargetElementLoop(
+ const HloInstruction& hlo,
+ const llvm_ir::ElementGenerator& element_generator) {
+ CHECK(Thunk::Kind::kKernel == LastThunk()->kind());
+ return EmitTargetElementLoopInThunk(hlo, element_generator,
+ static_cast<KernelThunk*>(LastThunk()));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
new file mode 100644
index 0000000000..14760fe92c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+using Index = BufferAllocation::Index;
+
+KernelThunk::KernelThunk(tensorflow::gtl::ArraySlice<Index> io_buffers,
+ const string& kernel_name,
+ const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kKernel, hlo_instruction),
+ io_buffers_(io_buffers.begin(), io_buffers.end()),
+ kernel_name_(kernel_name) {}
+
+tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) {
+ tensorflow::mutex_lock lock(mutex_);
+ if (loader_spec_) {
+ // Already initialized by another thread.
+ return tensorflow::Status::OK();
+ }
+
+ loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1));
+ tensorflow::StringPiece ptx = executable.ptx();
+ // Convert tensorflow::StringPiece to se::port::StringPiece because
+ // StreamExecutor uses the latter.
+ loader_spec_->AddCudaPtxInMemory(
+ se::port::StringPiece(ptx.data(), ptx.size()), kernel_name_);
+ return tensorflow::Status::OK();
+}
+
+void KernelThunk::SetLaunchDimensions(const LaunchDimensions& launch_dims) {
+ tensorflow::mutex_lock lock(mutex_);
+ launch_dimensions_ = launch_dims;
+}
+
+tensorflow::Status KernelThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ // Load the kernel.
+ se::StreamExecutor* executor = stream->parent();
+ se::KernelBase kernel(executor);
+ LaunchDimensions launch_dimensions;
+ {
+ tensorflow::mutex_lock lock(mutex_);
+ if (!executor->GetKernel(*loader_spec_, &kernel)) {
+ return InternalError("Unable to load kernel %s", kernel_name_.c_str());
+ }
+ launch_dimensions = launch_dimensions_;
+ }
+
+ // Launch the kernel with potentially multiple blocks and threads.
+ static constexpr int kKernelArgsLimit = 1024;
+ auto kernel_args = MakeUnique<se::KernelArgsArray<kKernelArgsLimit>>();
+ for (BufferAllocation::Index io_buffer : io_buffers_) {
+ kernel_args->add_device_memory_argument(
+ buffer_allocations.GetDeviceAddress(io_buffer));
+ }
+ kernel_args->add_device_memory_argument(
+ buffer_allocations.GetTempBufferBase());
+ if (!stream->parent()->Launch(
+ stream, se::ThreadDim(launch_dimensions.threads_per_block()),
+ se::BlockDim(launch_dimensions.block_count()), kernel,
+ *kernel_args)) {
+ return InternalError("Unable to launch kernel %s", kernel_name_.c_str());
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
new file mode 100644
index 0000000000..901825873a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h
@@ -0,0 +1,86 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuExecutable;
+
+// This class stores everything that StreamExecutor needs for launching a
+// kernel. It implements the ExecuteOnStream interface for GpuExecutable to
+// invoke the corresponding kernel.
+//
+// This is thread-compatible.
+class KernelThunk : public Thunk {
+ public:
+ // Constructs a thunk for the given kernel.
+ //
+ // `hlo_instruction` is as in Thunk. The other arguments correspond to the
+ // identically named class members.
+ KernelThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Index> io_buffers,
+ const string& kernel_name, const HloInstruction* hlo_instruction);
+ KernelThunk(const KernelThunk&) = delete;
+ KernelThunk& operator=(const KernelThunk&) = delete;
+ ~KernelThunk() override = default;
+
+ const string& kernel_name() const { return kernel_name_; }
+ void SetLaunchDimensions(const LaunchDimensions& launch_dims);
+
+ tensorflow::Status Initialize(const GpuExecutable& executable) override;
+
+ // Executes the kernel for the thunk on "stream", which must be non-null.
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ // The indices of the input/output buffers.
+ const std::vector<BufferAllocation::Index> io_buffers_;
+
+ // Entry kernel name for the computation.
+ const string kernel_name_;
+
+ // The thread and block dimensions used to launch the kernel.
+ // Will be set by IrEmitterUnnested.
+ LaunchDimensions launch_dimensions_;
+
+ // Describes how to load this kernel. ExecuteOnStream reuses this loader
+ // specification for all executions.
+ mutable tensorflow::mutex mutex_;
+ std::unique_ptr<perftools::gputools::MultiKernelLoaderSpec> loader_spec_
+ GUARDED_BY(mutex_);
+};
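+
+// Typical usage (hypothetical names): IrEmitterUnnested constructs the thunk
+// and sets its launch dimensions; GpuExecutable then drives it, e.g.
+//   auto thunk = MakeUnique<KernelThunk>(io_buffers, "my_kernel", inst);
+//   thunk->SetLaunchDimensions(launch_dims);
+//   TF_RETURN_IF_ERROR(thunk->Initialize(gpu_executable));
+//   TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));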
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_KERNEL_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc
new file mode 100644
index 0000000000..ff6cfd9448
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+Status GpuLayoutAssignment::AddBackendConstraints(
+ LayoutConstraints* constraints) {
+ for (auto& instruction : constraints->computation()->instructions()) {
+ // cuDNN is called with specific layouts on the input, output, and filter:
+ //
+ // input: DataLayout::kBatchDepthYX
+ // output: DataLayout::kBatchDepthYX
+ // filter: FilterLayout::kOutputInputYX
+ //
+ // The order of dimensions in the constant name is major-to-minor (e.g., the
+ // most-major dimension of the input is batch, the most-minor is X). The
+ // specific dimension numbers these named dimensions correspond to are
+ // determined by the ConvolutionDimensionNumbers argument. Y is spatial
+ // dimension 0, and X is spatial dimension 1.
+ //
+ // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls.
+ if (ImplementedAsDnnConvolution(*instruction)) {
+ HloInstruction* input = nullptr;
+ HloInstruction* filter = nullptr;
+ HloInstruction* output = nullptr;
+ if (instruction->opcode() == HloOpcode::kConvolution) {
+ input = instruction->mutable_operand(0);
+ filter = instruction->mutable_operand(1);
+ output = instruction.get();
+ } else {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ switch (instruction->fusion_kind()) {
+ case HloInstruction::FusionKind::kConvBackwardFilter:
+ // filter = BackwardFilterConvolve(input, output)
+ input = instruction->mutable_operand(0);
+ filter = instruction.get();
+ output = instruction->mutable_operand(1);
+ break;
+ case HloInstruction::FusionKind::kConvBackwardInput:
+ // input = BackwardInputConvolve(output, filter)
+ input = instruction.get();
+ filter = instruction->mutable_operand(1);
+ output = instruction->mutable_operand(0);
+ break;
+ default:
+ LOG(FATAL) << "Not a convolution-fusion";
+ }
+ }
+
+ // Construct minor-to-major dimension orders for operands and result.
+ // cuDNN's convolution APIs support the BDYX layout for activations/output
+ // and the OIYX layout for weights.
+ // TODO(b/29399649): Be more flexible about handling layouts of cuDNN
+ // calls after we switch to cuDNN v5.
+ const ConvolutionDimensionNumbers& dimension_numbers =
+ instruction->convolution_dimension_numbers();
+ Shape input_shape(input->shape());
+ *input_shape.mutable_layout() =
+ LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1),
+ dimension_numbers.spatial_dimensions(0),
+ dimension_numbers.feature_dimension(),
+ dimension_numbers.batch_dimension()});
+
+ Shape filter_shape(filter->shape());
+ *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(
+ {dimension_numbers.kernel_spatial_dimensions(1),
+ dimension_numbers.kernel_spatial_dimensions(0),
+ dimension_numbers.kernel_input_feature_dimension(),
+ dimension_numbers.kernel_output_feature_dimension()});
+
+ Shape output_shape(output->shape());
+ *output_shape.mutable_layout() =
+ LayoutUtil::MakeLayout({dimension_numbers.spatial_dimensions(1),
+ dimension_numbers.spatial_dimensions(0),
+ dimension_numbers.feature_dimension(),
+ dimension_numbers.batch_dimension()});
+
+ // Set layouts of the instructions' shapes.
+ if (instruction->opcode() == HloOpcode::kConvolution) {
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(input_shape, output, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(filter_shape, output, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(output_shape, output));
+ } else {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ switch (instruction->fusion_kind()) {
+ case HloInstruction::FusionKind::kConvBackwardFilter:
+ // filter = BackwardFilterConvolve(input, output)
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(input_shape, filter, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(filter_shape, filter));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(output_shape, filter, 1));
+ break;
+ case HloInstruction::FusionKind::kConvBackwardInput:
+ // input = BackwardInputConvolve(output, filter)
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(input_shape, input));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(output_shape, input, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(filter_shape, input, 1));
+ break;
+ default:
+ LOG(FATAL) << "Not a convolution-fusion";
+ }
+ }
+ }
+ }
+ return Status::OK();
+}
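+
+// For example (hypothetical dimension numbers): with batch=0, feature=1 and
+// spatial dimensions {2, 3}, the minor-to-major layout constructed above for
+// the input is {3, 2, 1, 0}, i.e. X is most-minor and batch is most-major,
+// which is cuDNN's kBatchDepthYX (NCHW) layout.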
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/layout_assignment.h
new file mode 100644
index 0000000000..169041eb85
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+namespace gpu {
+
+// GPU-specific layout assignment pass which preassigns layouts to satisfy
+// layout constraints for operands and results of library calls.
+class GpuLayoutAssignment : public LayoutAssignment {
+ public:
+ explicit GpuLayoutAssignment(ComputationLayout* entry_computation_layout)
+ : LayoutAssignment(entry_computation_layout) {}
+ ~GpuLayoutAssignment() override {}
+
+ protected:
+ Status AddBackendConstraints(LayoutConstraints* constraints) override;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc
new file mode 100644
index 0000000000..692ec8147d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+using LayoutAssignmentTest = HloTestBase;
+
+TEST_F(LayoutAssignmentTest, Elementwise) {
+ Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+ Shape ashape_in_row_major(ashape);
+ Shape ashape_in_col_major(ashape);
+ *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ // Enumerate all possible combinations of layouts.
+ for (const Shape& lhs_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ for (const Shape& rhs_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ for (const Shape& result_shape_with_layout :
+ {ashape_in_row_major, ashape_in_col_major}) {
+ // GpuLayoutAssignment should assign the same layout to "add" and its
+ // two operands.
+ auto builder = HloComputation::Builder(TestName());
+ auto x = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "x"));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ashape, "y"));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y));
+ HloModule module(TestName());
+ HloComputation* computation =
+ module.AddEntryComputation(builder.Build(add));
+
+ ComputationLayout computation_layout(
+ computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(lhs_shape_with_layout);
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(rhs_shape_with_layout);
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(result_shape_with_layout);
+
+ GpuLayoutAssignment layout_assignment(&computation_layout);
+ EXPECT_TRUE(layout_assignment.Run(&module).ValueOrDie());
+
+ for (const HloInstruction* operand : add->operands()) {
+ EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(),
+ operand->shape().layout()));
+ }
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
new file mode 100644
index 0000000000..fc0970049a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD
@@ -0,0 +1,88 @@
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [":friends"],
+ features = [
+ "-parse_headers",
+ "no_layering_check",
+ ],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+cc_library(
+ name = "llvm_gpu_backend",
+ srcs = [
+ "dump_ir_pass.cc",
+ "gpu_backend_lib.cc",
+ "utils.cc",
+ ],
+ hdrs = [
+ "dump_ir_pass.h",
+ "gpu_backend_lib.h",
+ "utils.h",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/legacy_flags:gpu_backend_lib_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "@llvm//:analysis",
+ "@llvm//:asm_printer",
+ "@llvm//:bit_reader",
+ "@llvm//:bit_writer",
+ "@llvm//:code_gen",
+ "@llvm//:core",
+ "@llvm//:instrumentation",
+ "@llvm//:ipo",
+ "@llvm//:ir_reader",
+ "@llvm//:linker",
+ "@llvm//:mc",
+ "@llvm//:nvptx_code_gen",
+ "@llvm//:objc_arc",
+ "@llvm//:support",
+ "@llvm//:target",
+ "@llvm//:transform_utils",
+ ],
+)
+
+cc_test(
+ name = "utils_test",
+ size = "small",
+ srcs = ["utils_test.cc"],
+ data = [
+ "tests_data/saxpy.ll",
+ ],
+ deps = [
+ ":llvm_gpu_backend",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "@llvm//:core",
+ "@llvm//:support",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
new file mode 100644
index 0000000000..aeec3a03ca
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.cc
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/Support/FileSystem.h"
+#include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+
+// Pass which dumps the IR of a module into a file.
+//
+// Because it is implemented as a FunctionPass (IR is dumped
+// function-by-function) rather than as a ModulePass, the resulting IR is not
+// valid (missing metadata, for example), but it is still useful for inspection.
+// The pass needs to be a FunctionPass rather than a ModulePass because
+// inserting ModulePasses is disruptive to LLVM's pass manager. For sequential
+// FunctionPasses (also SCC passes, etc) the pass manager executes the passes
+// sequentially on each function (SCC, etc). Inserting a ModulePass between
+// FunctionPasses acts as a barrier forcing the FunctionPasses to execute fully
+// across all functions prior to advancing to the next pass. For some reason
+// this results in different generated code, an undesirable Heisenberg effect
+// when dumping the IR.
+class DumpIrPass : public llvm::FunctionPass {
+ public:
+ explicit DumpIrPass(const string &output_filename)
+ : llvm::FunctionPass(id_), output_filename_(output_filename) {}
+
+ bool doInitialization(llvm::Module &M) override {
+ out_.reset(new llvm::raw_fd_ostream(llvm::StringRef(output_filename_), ec_,
+ llvm::sys::fs::F_None));
+ if (ec_) {
+ LOG(FATAL) << "Unable to open " << output_filename_
+ << " to dump LLVM IR: " << ec_.message();
+ }
+ return false;
+ }
+
+ bool runOnFunction(llvm::Function &Function) override {
+ Function.print(*out_);
+ return false;
+ }
+
+ void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+ AU.setPreservesAll();
+ }
+
+ bool doFinalization(llvm::Module &M) override {
+ out_->close();
+ return false;
+ }
+
+ private:
+ static char id_;
+ string output_filename_;
+ std::error_code ec_;
+ std::unique_ptr<llvm::raw_fd_ostream> out_;
+};
+
+char DumpIrPass::id_ = 0;
+
+void IrDumpingPassManager::run(llvm::Module &module) {
+ for (int i = 0; i < passes_.size(); ++i) {
+ llvm::Pass *P = passes_[i];
+ if (dump_ir_) {
+ const llvm::PassInfo *PI =
+ llvm::PassRegistry::getPassRegistry()->getPassInfo(P->getPassID());
+ const string basename = ReplaceFilenameExtension(
+ tensorflow::io::Basename(input_filename_),
+ tensorflow::strings::Printf(
+ "pass-%02d.before.%s.ll", i,
+ (PI == nullptr ? "unknown" : PI->getPassArgument().data())));
+ llvm::legacy::PassManager::add(
+ new DumpIrPass(tensorflow::io::JoinPath(output_dir_, basename)));
+ }
+ llvm::legacy::PassManager::add(P);
+ }
+ passes_.clear();
+ llvm::legacy::PassManager::run(module);
+}
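+
+// With dump_ir_ enabled, run() interleaves a DumpIrPass before every real
+// pass, so a module whose input filename is e.g. "foo.ll" (hypothetical) is
+// dumped to files named like "foo.pass-00.before.verify.ll",
+// "foo.pass-01.before.sroa.ll", ... in output_dir_, one snapshot before each
+// pass executes (pass names here are illustrative).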
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h
new file mode 100644
index 0000000000..1d515a0f28
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_
+
+#include <string>
+
+#include "external/llvm/include/llvm/IR/LegacyPassManager.h"
+#include "external/llvm/include/llvm/Pass.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+// Pass manager which optionally dumps the IR to a sequence of files before each
+// pass.
+class IrDumpingPassManager : public llvm::legacy::PassManager {
+ public:
+ IrDumpingPassManager(const string& input_filename, const string& output_dir,
+ bool dump_ir)
+ : llvm::legacy::PassManager(),
+ input_filename_(input_filename),
+ output_dir_(output_dir),
+ dump_ir_(dump_ir) {}
+ void add(llvm::Pass* P) { passes_.push_back(P); }
+ void run(llvm::Module& module); // NOLINT(runtime/references)
+
+ private:
+ string input_filename_;
+ string output_dir_;
+ bool dump_ir_;
+ std::vector<llvm::Pass*> passes_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_DUMP_IR_PASS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
new file mode 100644
index 0000000000..e15659938c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc
@@ -0,0 +1,489 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h"
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+#include "external/llvm/include/llvm/ADT/STLExtras.h"
+#include "external/llvm/include/llvm/ADT/StringMap.h"
+#include "external/llvm/include/llvm/ADT/StringSet.h"
+#include "external/llvm/include/llvm/Analysis/TargetLibraryInfo.h"
+#include "external/llvm/include/llvm/Analysis/TargetTransformInfo.h"
+#include "external/llvm/include/llvm/Bitcode/BitcodeReader.h"
+#include "external/llvm/include/llvm/Bitcode/BitcodeWriter.h"
+#include "external/llvm/include/llvm/CodeGen/CommandFlags.h"
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/LegacyPassManager.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/LinkAllIR.h"
+#include "external/llvm/include/llvm/LinkAllPasses.h"
+#include "external/llvm/include/llvm/Linker/Linker.h"
+#include "external/llvm/include/llvm/PassRegistry.h"
+#include "external/llvm/include/llvm/Support/CommandLine.h"
+#include "external/llvm/include/llvm/Support/FileSystem.h"
+#include "external/llvm/include/llvm/Support/FormattedStream.h"
+#include "external/llvm/include/llvm/Support/TargetRegistry.h"
+#include "external/llvm/include/llvm/Support/TargetSelect.h"
+#include "external/llvm/include/llvm/Support/ToolOutputFile.h"
+#include "external/llvm/include/llvm/Target/TargetMachine.h"
+#include "external/llvm/include/llvm/Transforms/IPO.h"
+#include "external/llvm/include/llvm/Transforms/IPO/AlwaysInliner.h"
+#include "external/llvm/include/llvm/Transforms/IPO/PassManagerBuilder.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+// Default inline threshold value to use in llvm.
+const int kDefaultInlineThreshold = 1100;
+
+// Information about a GPU architecture for the backend.
+struct GpuBackendInfo {
+ string libdevice_name;
+ string sm_name;
+};
+
+// Maps supported CUDA compute capability to a libdevice file to link for this
+// capability.
+std::map<string, GpuBackendInfo> gpu_info_map = {
+ {"compute_20", {"libdevice.compute_20.10.bc", "sm_20"}},
+ {"compute_30", {"libdevice.compute_30.10.bc", "sm_30"}},
+ {"compute_35", {"libdevice.compute_35.10.bc", "sm_35"}},
+
+ // NVIDIA does not provide a separate libdevice for CC 3.7, but we can use
+ // the one for 3.5.
+ {"compute_37", {"libdevice.compute_35.10.bc", "sm_37"}},
+};
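+
+// For example, --gpu_architecture=compute_35 links libdevice.compute_35.10.bc
+// and targets the sm_35 processor in the NVPTX backend.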
+
+// Validate the --gpu_architecture command-line flag.
+static void ValidateGPUArchitecture(const string& value) {
+ if (!gpu_info_map.count(value)) {
+ LOG(FATAL) << "value for --gpu_architecture must be compute_{20,30,35,37}";
+ }
+}
+
+// Convenience function for producing a name of a temporary compilation product
+// from the input filename.
+string MakeNameForTempProduct(const std::string& input_filename,
+ tensorflow::StringPiece extension) {
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ return tensorflow::io::JoinPath(
+ flags->dump_temp_products_to,
+ ReplaceFilenameExtension(
+ tensorflow::io::Basename(llvm_ir::AsString(input_filename)),
+ extension));
+}
+
+// Initializes LLVM passes. Uses the PassRegistry mechanism.
+void InitializePasses(llvm::PassRegistry* pass_registry) {
+ llvm::initializeCore(*pass_registry);
+ llvm::initializeCodeGen(*pass_registry);
+ llvm::initializeScalarOpts(*pass_registry);
+ llvm::initializeObjCARCOpts(*pass_registry);
+ llvm::initializeVectorization(*pass_registry);
+ llvm::initializeIPO(*pass_registry);
+ llvm::initializeAnalysis(*pass_registry);
+ llvm::initializeTransformUtils(*pass_registry);
+ llvm::initializeInstCombine(*pass_registry);
+ llvm::initializeInstrumentation(*pass_registry);
+ llvm::initializeTarget(*pass_registry);
+ llvm::initializeCodeGenPreparePass(*pass_registry);
+}
+
+// Returns the TargetMachine, given a triple.
+std::unique_ptr<llvm::TargetMachine> GetTargetMachine(
+ llvm::Triple triple, tensorflow::StringPiece cpu_name) {
+ std::string error;
+ const llvm::Target* target = TargetRegistry::lookupTarget("", triple, error);
+ if (target == nullptr) {
+ LOG(FATAL) << "Unable to find Target for triple '" << triple.str() << "'"
+ << " -- " << error;
+ return nullptr;
+ }
+
+ TargetOptions target_options = InitTargetOptionsFromCodeGenFlags();
+ // Enable FMA synthesis if desired.
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ if (flags->fma) {
+ target_options.AllowFPOpFusion = FPOpFusion::Fast;
+ }
+
+ // Set options from LlvmBackendFlags (specifically, fast-math flags).
+ llvm_ir::SetTargetOptions(&target_options);
+
+ // Set the verbose assembly options.
+ target_options.MCOptions.AsmVerbose = flags->verbose_ptx_asm;
+
+ // The selection of codegen optimization level is copied from function
+ // GetCodeGenOptLevel in //external/llvm/tools/opt/opt.cpp.
+ CodeGenOpt::Level codegen_opt_level;
+ switch (flags->opt_level) {
+ case 1:
+ codegen_opt_level = CodeGenOpt::Less;
+ break;
+ case 2:
+ codegen_opt_level = CodeGenOpt::Default;
+ break;
+ case 3:
+ codegen_opt_level = CodeGenOpt::Aggressive;
+ break;
+ default:
+ codegen_opt_level = CodeGenOpt::None;
+ }
+ return WrapUnique(target->createTargetMachine(
+ triple.str(), llvm_ir::AsStringRef(cpu_name), "+ptx42", target_options,
+ Optional<Reloc::Model>(RelocModel), CMModel, codegen_opt_level));
+}
+
+// Adds the standard LLVM optimization passes, based on the speed optimization
+// level (opt_level) and size optimization level (size_level). Both module
+// and function-level passes are added, so two pass managers are passed in and
+// modified by this function.
+void AddOptimizationPasses(unsigned opt_level, unsigned size_level,
+ llvm::TargetMachine* target_machine,
+ llvm::legacy::PassManagerBase* module_passes,
+ llvm::legacy::FunctionPassManager* function_passes) {
+ PassManagerBuilder builder;
+ builder.OptLevel = opt_level;
+ builder.SizeLevel = size_level;
+
+ if (opt_level > 1) {
+ builder.Inliner = llvm::createFunctionInliningPass(kDefaultInlineThreshold);
+ } else {
+ // Only inline functions marked with "alwaysinline".
+ builder.Inliner = llvm::createAlwaysInlinerLegacyPass();
+ }
+
+ builder.DisableUnitAtATime = false;
+ builder.DisableUnrollLoops = opt_level == 0;
+ builder.LoopVectorize = opt_level > 0;
+ builder.SLPVectorize = opt_level > 1 && size_level < 2;
+
+ // NVPTX's early-as-possible passes include NVVM reflect.
+ builder.addExtension(
+ llvm::PassManagerBuilder::EP_EarlyAsPossible,
+ [&](const PassManagerBuilder&, legacy::PassManagerBase& pass_manager) {
+ target_machine->addEarlyAsPossiblePasses(pass_manager);
+ });
+
+ builder.populateFunctionPassManager(*function_passes);
+ builder.populateModulePassManager(*module_passes);
+}
+
+// Emits the given module to a bit code file.
+void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) {
+ std::error_code error_code;
+ llvm::tool_output_file outfile(filename.ToString().c_str(), error_code,
+ llvm::sys::fs::F_None);
+ if (error_code) {
+ LOG(FATAL) << "opening bitcode file for writing: " << error_code.message();
+ }
+
+ llvm::WriteBitcodeToFile(&module, outfile.os());
+ outfile.keep();
+}
+
+// Emits the given module to PTX. target_machine is an initialized TargetMachine
+// for the NVPTX target.
+string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) {
+ std::string ptx; // need a std::string instead of a ::string.
+ {
+ llvm::raw_string_ostream stream(ptx);
+ llvm::buffer_ostream pstream(stream);
+ // The extension is stripped by IrDumpingPassManager, so we need to
+ // get creative to add a suffix.
+ string module_id(llvm_ir::AsString(module->getModuleIdentifier()));
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ IrDumpingPassManager codegen_passes(
+ ReplaceFilenameExtension(tensorflow::io::Basename(module_id),
+ "-nvptx.dummy"),
+ flags->dump_temp_products_to, flags->dump_ir_before_passes);
+ codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass(
+ llvm::Triple(module->getTargetTriple())));
+
+ target_machine->addPassesToEmitFile(codegen_passes, pstream,
+ llvm::TargetMachine::CGFT_AssemblyFile);
+ codegen_passes.run(*module);
+ }
+
+ return ptx;
+}
+
+// LLVM has an extensive flags mechanism of its own, which is only accessible
+// through the command line. Internal libraries within LLVM register parsers for
+// flags, with no other way to configure them except by passing these flags.
+// To do this programmatically, we invoke ParseCommandLineOptions manually with
+// a "fake argv".
+// Note: setting flags with this method is stateful, since flags are just
+// static globals within LLVM libraries.
+void FeedLLVMWithFlags(const std::vector<string>& cl_opts) {
+ std::vector<const char*> fake_argv = {""};
+ for (const string& cl_opt : cl_opts) {
+ fake_argv.push_back(cl_opt.c_str());
+ }
+ llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]);
+}
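+
+// Example: GPUBackendInit below uses this to tune backend heuristics, e.g.
+//   FeedLLVMWithFlags({"-bonus-inst-threshold=2"});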
+
+namespace {
+// Returns whether the module could use any libdevice functions. This function
+// may have false positives -- the module might not use libdevice even if this
+// function returns true.
+bool CouldNeedLibdevice(const llvm::Module& module) {
+ for (const llvm::Function& function : module.functions()) {
+ // This is a conservative approximation -- not all such functions are in
+ // libdevice.
+ if (!function.isIntrinsic() && function.isDeclaration()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Links libdevice into the given module if the module needs libdevice.
+tensorflow::Status LinkLibdeviceIfNecessary(const string& libdevice_dir_path,
+ llvm::Module* module) {
+ if (!CouldNeedLibdevice(*module)) {
+ return tensorflow::Status::OK();
+ }
+
+ llvm::Linker linker(*module);
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ ValidateGPUArchitecture(flags->gpu_architecture);
+ string libdevice_bc_filename =
+ gpu_info_map[flags->gpu_architecture].libdevice_name;
+ string libdevice_bc_fullpath =
+ tensorflow::io::JoinPath(libdevice_dir_path, libdevice_bc_filename);
+ TF_RETURN_IF_ERROR(
+ tensorflow::Env::Default()->FileExists(libdevice_bc_fullpath));
+ std::unique_ptr<llvm::Module> libdevice_module =
+ LoadIRModule(libdevice_bc_fullpath, &module->getContext());
+ VLOG(1) << "Linking with libdevice from: " << libdevice_bc_fullpath;
+ if (linker.linkInModule(std::move(libdevice_module),
+ llvm::Linker::Flags::InternalizeLinkedSymbols |
+ llvm::Linker::Flags::LinkOnlyNeeded)) {
+ LOG(FATAL) << "Error linking libdevice from " << libdevice_bc_fullpath;
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+StatusOr<string> CompileModuleToPtx(llvm::Module* module,
+ const string& libdevice_dir_path) {
+ // Link the input module with libdevice, to pull in implementations of some
+ // builtins.
+ TF_RETURN_IF_ERROR(LinkLibdeviceIfNecessary(libdevice_dir_path, module));
+
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ if (!flags->dump_temp_products_to.empty()) {
+ string linked_filename =
+ MakeNameForTempProduct(module->getModuleIdentifier(), "linked.bc");
+ LOG(INFO) << "dumping bitcode after linking libdevice to: "
+ << linked_filename;
+ EmitBitcodeToFile(*module, linked_filename);
+ }
+
+ // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass
+ // can access it.
+ module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", flags->ftz);
+
+ // If ftz is enabled, set it as an attribute on every function in the module.
+ if (flags->ftz) {
+ for (llvm::Function& fn : *module) {
+ fn.addFnAttr("nvptx-f32ftz", "true");
+ }
+ }
+
+ // Run IR-level optimizations.
+ if (flags->dump_ir_before_passes && flags->dump_temp_products_to.empty()) {
+ LOG(FATAL) << "--dump_ir_before_passes must be specified with "
+ "--dump_temp_products_to";
+ }
+
+ IrDumpingPassManager module_passes(module->getModuleIdentifier(),
+ flags->dump_temp_products_to,
+ flags->dump_ir_before_passes);
+
+ // Add an appropriate TargetLibraryInfo pass for the module's triple.
+ llvm::TargetLibraryInfoWrapperPass* tliwp =
+ new llvm::TargetLibraryInfoWrapperPass(
+ llvm::Triple(module->getTargetTriple()));
+ module_passes.add(tliwp);
+
+ // Try to fetch the target triple from the module. If not present, set a
+ // default target triple.
+ llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());
+ if (target_triple.getArch() == llvm::Triple::UnknownArch) {
+ LOG(WARNING) << "target triple not found in the module";
+ target_triple = llvm::Triple("nvptx64-unknown-unknown");
+ }
+
+ // Figure out the exact name of the processor as known to the NVPTX backend
+ // from the gpu_architecture flag.
+ ValidateGPUArchitecture(flags->gpu_architecture);
+ string cpu_name = gpu_info_map[flags->gpu_architecture].sm_name;
+
+ std::unique_ptr<llvm::TargetMachine> target_machine =
+ GetTargetMachine(target_triple, cpu_name);
+ module_passes.add(llvm::createTargetTransformInfoWrapperPass(
+ target_machine->getTargetIRAnalysis()));
+
+ // The LLVM IR verifier performs sanity checking on the IR. This helps
+ // discover problems and report them in a meaningful manner, rather than
+ // letting later passes report obscure assertions because of unfulfilled
+ // invariants.
+ module_passes.add(llvm::createVerifierPass());
+
+ // Create the function-level pass manager. It needs data layout information
+ // too.
+ llvm::legacy::FunctionPassManager function_passes(module);
+
+ AddOptimizationPasses(flags->opt_level, /*size_level=*/0,
+ target_machine.get(), &module_passes, &function_passes);
+ // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA
+ // again after the standard optimization passes [http://b/13329423].
+ // TODO(jingyue): SROA may further expose more optimization opportunities, such
+ // as more precise alias analysis and more function inlining (SROA may change
+ // the inlining cost of a function). For now, running SROA already emits good
+ // enough code for the evaluated benchmarks. We may want to run more
+ // optimizations later.
+ if (flags->opt_level > 0) {
+ // LLVM's optimizer turns on SROA when the optimization level is greater
+ // than 0. We mimic this behavior here.
+ module_passes.add(llvm::createSROAPass());
+ }
+
+ // Verify that the module is well formed after optimizations ran.
+ module_passes.add(llvm::createVerifierPass());
+
+ // Done populating the pass managers. Now run them.
+
+ function_passes.doInitialization();
+ for (auto func = module->begin(); func != module->end(); ++func) {
+ function_passes.run(*func);
+ }
+ function_passes.doFinalization();
+ module_passes.run(*module);
+
+ if (!flags->dump_temp_products_to.empty()) {
+ string optimized_filename =
+ MakeNameForTempProduct(module->getModuleIdentifier(), "optimized.bc");
+ LOG(INFO) << "dumping bitcode after optimizations to: "
+ << optimized_filename;
+ EmitBitcodeToFile(*module, optimized_filename);
+ }
+
+ // Finally, produce PTX.
+ return EmitModuleToPTX(module, target_machine.get());
+}
+
+// One-time module initializer.
+// Must be called only once -- DO NOT CALL DIRECTLY.
+void GPUBackendInit() {
+ // Feed all customized flags here, so we can override them with llvm_cl_opts
+ // without redeploying the compiler for development purposes.
+
+ // This flag tunes a threshold in branch folding. The default threshold, which
+ // is one, is not suitable for CUDA programs where branches are more expensive
+ // than for CPU programs. Setting the threshold to 2 improves the latency of
+ // TwoDPatchDotProductKernel_IND_3_ND_48 by over 5%, and does not affect the
+ // latency of other benchmarks so far.
+ //
+ // I also tried setting this threshold to other values:
+ // * 3-6 gives similar results as 2;
+ // * >6 start hurting the performance of at least dot product kernels.
+ //
+ // TODO(jingyue): The current threshold only considers the number of IR
+ // instructions, which does not accurately reflect the true cost. We need a
+ // better cost model.
+ FeedLLVMWithFlags({"-bonus-inst-threshold=2"});
+ // TODO(b/22073864): Increase the limit when scanning memory dependencies.
+ // This helps to eliminate more redundant load instructions.
+ //
+ // The specific value is currently large enough for s3d in shoc benchmark,
+ // which contains a lot of load instructions and many arithmetic instructions
+ // between those loads.
+ FeedLLVMWithFlags({"-memdep-block-scan-limit=500"});
+
+ legacy_flags::GpuBackendLibFlags* flags =
+ legacy_flags::GetGpuBackendLibFlags();
+ if (!flags->llvm_cl_opts.empty()) {
+ std::vector<string> opts =
+ tensorflow::str_util::Split(flags->llvm_cl_opts, ',');
+ FeedLLVMWithFlags(opts);
+ }
+
+ if (flags->llvm_dump_passes) {
+ // Enable LLVM pass debugging dump. LLVM dumps this information when a pass
+ // manager is initialized for execution. It's done to stderr (this is
+ // hardcoded within LLVM to the dbgs() stream, we can't change it from the
+ // outside).
+ FeedLLVMWithFlags({"-debug-pass=Arguments"});
+ }
+
+ // Initialize the NVPTX target; it's the only target we link with, so call its
+ // specific initialization functions instead of the catch-all InitializeAll*.
+ LLVMInitializeNVPTXTarget();
+ LLVMInitializeNVPTXTargetInfo();
+ LLVMInitializeNVPTXTargetMC();
+ LLVMInitializeNVPTXAsmPrinter();
+
+ // Initialize the LLVM optimization passes.
+ llvm::PassRegistry* registry = llvm::PassRegistry::getPassRegistry();
+ InitializePasses(registry);
+}
+
+} // namespace
+
+StatusOr<string> CompileToPtx(llvm::Module* module,
+ const string& libdevice_dir_path) {
+ static std::once_flag backend_init_flag;
+ std::call_once(backend_init_flag, GPUBackendInit);
+
+ string ptx;
+ {
+ ScopedLoggingTimer compilation_timer(
+ "Compile module " + llvm_ir::AsString(module->getName()),
+ /*vlog_level=*/2);
+ TF_ASSIGN_OR_RETURN(ptx, CompileModuleToPtx(module, libdevice_dir_path));
+ }
+ return ptx;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
new file mode 100644
index 0000000000..3413f33301
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// LLVM-based compiler backend.
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
+
+#include <string>
+
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace xla {
+namespace gpu {
+
+// The Compile.* interfaces each create their own llvm::LLVMContext objects for
+// thread safety, but note that LLVM's multithreaded support is very
+// preliminary; multithreaded use is not recommended at this time.
+//
+// Compiles the argument module and returns the resulting PTX.
+// libdevice_dir_path is the parent directory of the libdevice bitcode
+// libraries. The contents of the module may be changed.
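+//
+// Example (hypothetical libdevice path):
+//   TF_ASSIGN_OR_RETURN(string ptx,
+//                       CompileToPtx(module, "/path/to/libdevice"));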
+StatusOr<string> CompileToPtx(llvm::Module* module,
+ const string& libdevice_dir_path);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_GPU_BACKEND_LIB_H_
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll
new file mode 100644
index 0000000000..2ae1d2f7ea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll
@@ -0,0 +1,141 @@
+target datalayout = "e-i64:64-v16:16-v32:32-n16:32:64"
+target triple = "nvptx64-unknown-unknown"
+
+%struct.uint3 = type { i32, i32, i32 }
+%struct.dim3 = type { i32, i32, i32 }
+
+@blockIdx = external addrspace(1) global %struct.uint3
+@blockDim = external addrspace(1) global %struct.dim3
+@threadIdx = external addrspace(1) global %struct.uint3
+
+; Function Attrs: alwaysinline nounwind readnone
+define float @expf(float %f) #0 {
+entry:
+ %f.addr = alloca float, align 4
+ store float %f, float* %f.addr, align 4
+ %0 = load float, float* %f.addr, align 4
+ %call = call float @__nv_expf(float %0)
+ ret float %call
+}
+
+declare float @__nv_expf(float) #1
+
+; Function Attrs: nounwind
+define void @cuda_saxpy(i32* %n, float* %a, float* %x, float* %y) #2 {
+entry:
+ %n.addr = alloca i32*, align 8
+ %a.addr = alloca float*, align 8
+ %x.addr = alloca float*, align 8
+ %y.addr = alloca float*, align 8
+ %i = alloca i32, align 4
+ store i32* %n, i32** %n.addr, align 8
+ store float* %a, float** %a.addr, align 8
+ store float* %x, float** %x.addr, align 8
+ store float* %y, float** %y.addr, align 8
+ %0 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @blockIdx to %struct.uint3*), i32 0, i32 0), align 4
+ %1 = load i32, i32* getelementptr inbounds (%struct.dim3, %struct.dim3* addrspacecast (%struct.dim3 addrspace(1)* @blockDim to %struct.dim3*), i32 0, i32 0), align 4
+ %mul = mul i32 %0, %1
+ %2 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @threadIdx to %struct.uint3*), i32 0, i32 0), align 4
+ %add = add i32 %mul, %2
+ store i32 %add, i32* %i, align 4
+ %3 = load i32, i32* %i, align 4
+ %4 = load i32*, i32** %n.addr, align 8
+ %arrayidx = getelementptr inbounds i32, i32* %4, i64 0
+ %5 = load i32, i32* %arrayidx, align 4
+ %cmp = icmp slt i32 %3, %5
+ br i1 %cmp, label %if.then, label %if.end
+
+if.then: ; preds = %entry
+ %6 = load float*, float** %a.addr, align 8
+ %arrayidx1 = getelementptr inbounds float, float* %6, i64 0
+ %7 = load float, float* %arrayidx1, align 4
+ %8 = load i32, i32* %i, align 4
+ %idxprom = sext i32 %8 to i64
+ %9 = load float*, float** %x.addr, align 8
+ %arrayidx2 = getelementptr inbounds float, float* %9, i64 %idxprom
+ %10 = load float, float* %arrayidx2, align 4
+ %mul3 = fmul float %7, %10
+ %11 = load i32, i32* %i, align 4
+ %idxprom4 = sext i32 %11 to i64
+ %12 = load float*, float** %y.addr, align 8
+ %arrayidx5 = getelementptr inbounds float, float* %12, i64 %idxprom4
+ %13 = load float, float* %arrayidx5, align 4
+ %add6 = fadd float %mul3, %13
+ %14 = load i32, i32* %i, align 4
+ %idxprom7 = sext i32 %14 to i64
+ %15 = load float*, float** %y.addr, align 8
+ %arrayidx8 = getelementptr inbounds float, float* %15, i64 %idxprom7
+ store float %add6, float* %arrayidx8, align 4
+ br label %if.end
+
+if.end: ; preds = %if.then, %entry
+ ret void
+}
+
+; Function Attrs: nounwind
+define void @cuda_saxpy_s(i32* %n, float* %a, float* %x, float* %y) #2 {
+entry:
+ %n.addr = alloca i32*, align 8
+ %a.addr = alloca float*, align 8
+ %x.addr = alloca float*, align 8
+ %y.addr = alloca float*, align 8
+ %i = alloca i32, align 4
+ store i32* %n, i32** %n.addr, align 8
+ store float* %a, float** %a.addr, align 8
+ store float* %x, float** %x.addr, align 8
+ store float* %y, float** %y.addr, align 8
+ %0 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @blockIdx to %struct.uint3*), i32 0, i32 0), align 4
+ %1 = load i32, i32* getelementptr inbounds (%struct.dim3, %struct.dim3* addrspacecast (%struct.dim3 addrspace(1)* @blockDim to %struct.dim3*), i32 0, i32 0), align 4
+ %mul = mul i32 %0, %1
+ %2 = load i32, i32* getelementptr inbounds (%struct.uint3, %struct.uint3* addrspacecast (%struct.uint3 addrspace(1)* @threadIdx to %struct.uint3*), i32 0, i32 0), align 4
+ %add = add i32 %mul, %2
+ store i32 %add, i32* %i, align 4
+ call void @llvm.cuda.syncthreads()
+ %3 = load i32, i32* %i, align 4
+ %4 = load i32*, i32** %n.addr, align 8
+ %arrayidx = getelementptr inbounds i32, i32* %4, i64 0
+ %5 = load i32, i32* %arrayidx, align 4
+ %cmp = icmp slt i32 %3, %5
+ br i1 %cmp, label %if.then, label %if.end
+
+if.then: ; preds = %entry
+ %6 = load float*, float** %a.addr, align 8
+ %arrayidx1 = getelementptr inbounds float, float* %6, i64 0
+ %7 = load float, float* %arrayidx1, align 4
+ %8 = load i32, i32* %i, align 4
+ %idxprom = sext i32 %8 to i64
+ %9 = load float*, float** %x.addr, align 8
+ %arrayidx2 = getelementptr inbounds float, float* %9, i64 %idxprom
+ %10 = load float, float* %arrayidx2, align 4
+ %mul3 = fmul float %7, %10
+ %11 = load i32, i32* %i, align 4
+ %idxprom4 = sext i32 %11 to i64
+ %12 = load float*, float** %y.addr, align 8
+ %arrayidx5 = getelementptr inbounds float, float* %12, i64 %idxprom4
+ %13 = load float, float* %arrayidx5, align 4
+ %add6 = fadd float %mul3, %13
+ %14 = load i32, i32* %i, align 4
+ %idxprom7 = sext i32 %14 to i64
+ %15 = load float*, float** %y.addr, align 8
+ %arrayidx8 = getelementptr inbounds float, float* %15, i64 %idxprom7
+ store float %add6, float* %arrayidx8, align 4
+ br label %if.end
+
+if.end: ; preds = %if.then, %entry
+ ret void
+}
+
+; Function Attrs: nounwind
+declare void @llvm.cuda.syncthreads() #3
+
+attributes #0 = { alwaysinline nounwind readnone "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-realign-stack" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
+
+!nvvm.annotations = !{!0, !1}
+!llvm.ident = !{!2}
+
+!0 = !{void (i32*, float*, float*, float*)* @cuda_saxpy, !"kernel", i32 1}
+!1 = !{void (i32*, float*, float*, float*)* @cuda_saxpy_s, !"kernel", i32 1}
+!2 = !{!"clang version xla-trunk (trunk r203011)"}
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
new file mode 100644
index 0000000000..c10346bbc2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.cc
@@ -0,0 +1,65 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IRReader/IRReader.h"
+#include "external/llvm/include/llvm/Support/SourceMgr.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace {
+
+static void DieWithSMDiagnosticError(llvm::SMDiagnostic* diagnostic) {
+ LOG(FATAL) << diagnostic->getLineNo() << ":" << diagnostic->getColumnNo()
+ << ": " << diagnostic->getMessage().str();
+}
+
+} // namespace
+
+namespace xla {
+namespace gpu {
+
+std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
+ llvm::LLVMContext* llvm_context) {
+ llvm::SMDiagnostic diagnostic_err;
+ std::unique_ptr<llvm::Module> module(
+ llvm::parseIRFile(llvm::StringRef(filename.data(), filename.size()),
+ diagnostic_err, *llvm_context));
+
+ if (module == nullptr) {
+ DieWithSMDiagnosticError(&diagnostic_err);
+ }
+
+ return module;
+}
+
+string ReplaceFilenameExtension(tensorflow::StringPiece filename,
+ tensorflow::StringPiece new_extension) {
+ auto pos = filename.rfind('.');
+ tensorflow::StringPiece stem =
+ pos == tensorflow::StringPiece::npos
+ ? filename
+ : tensorflow::StringPiece(filename.data(), pos);
+ return tensorflow::strings::StrCat(stem, ".", new_extension);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
new file mode 100644
index 0000000000..a6daeca95a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h
@@ -0,0 +1,50 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_
+
+#include <memory>
+#include <string>
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace llvm {
+class LLVMContext;
+class Module;
+}
+
+namespace xla {
+namespace gpu {
+
+// Convenience function for loading an LLVM module from an IR file. The module
+// is created in the given LLVM context.
+//
+// If loading fails for some reason, dies printing a diagnostic error.
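+//
+// Example usage (sketch; the .ll path below is illustrative):
+//   llvm::LLVMContext context;
+//   std::unique_ptr<llvm::Module> module =
+//       LoadIRModule("/tmp/saxpy.ll", &context);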
+std::unique_ptr<llvm::Module> LoadIRModule(const string& filename,
+ llvm::LLVMContext* llvm_context);
+
+// Convenience function for replacing the extension of the given filename.
+// If the filename has no extension, the new extension is appended to its name.
+//
+// For example:
+// ReplaceFilenameExtension("/foo/baz.txt", "cc") --> "/foo/baz.cc"
+string ReplaceFilenameExtension(tensorflow::StringPiece filename,
+ tensorflow::StringPiece new_extension);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LLVM_GPU_BACKEND_UTILS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc
new file mode 100644
index 0000000000..3848e58b0d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h"
+
+#include <string>
+
+#include "tensorflow/core/lib/io/path.h"
+
+#include "external/llvm/include/llvm/IR/LLVMContext.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+namespace {
+
+const char kSaxpyIRFile[] =
+ "compiler/xla/service/gpu/llvm_gpu_backend/tests_data/saxpy.ll";
+
+TEST(UtilsTest, TestLoadIRModule) {
+ llvm::LLVMContext llvm_context;
+ string test_srcdir = tensorflow::testing::TensorFlowSrcRoot();
+ std::unique_ptr<llvm::Module> module = LoadIRModule(
+ tensorflow::io::JoinPath(test_srcdir, kSaxpyIRFile), &llvm_context);
+ // Sanity check that the module was loaded properly.
+ ASSERT_NE(nullptr, module);
+ ASSERT_NE(std::string::npos, module->getModuleIdentifier().find("saxpy.ll"));
+ ASSERT_NE(nullptr, module->getFunction("cuda_saxpy"));
+}
+
+TEST(UtilsTest, TestReplaceFilenameExtension) {
+ ASSERT_EQ(ReplaceFilenameExtension("baz.tx", "cc"), "baz.cc");
+ ASSERT_EQ(ReplaceFilenameExtension("/foo/baz.txt", "cc"), "/foo/baz.cc");
+ ASSERT_EQ(ReplaceFilenameExtension("/foo/baz.", "-nvptx.dummy"),
+ "/foo/baz.-nvptx.dummy");
+ ASSERT_EQ(ReplaceFilenameExtension("/foo/baz", "cc"), "/foo/baz.cc");
+}
+
+} // namespace
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
new file mode 100644
index 0000000000..ca70d55fab
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -0,0 +1,408 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/pad_insertion.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+bool IsForwardConvolutionCanonical(const HloInstruction& conv) {
+ CHECK_EQ(HloOpcode::kConvolution, conv.opcode());
+ return window_util::HasEvenPadding(conv.window()) &&
+ !window_util::HasNegativePadding(conv.window()) &&
+ !window_util::HasDilation(conv.window());
+}
+
+// If the (positive and negative) padding on the input operand of a convolution
+// can't be folded into a cuDNN convolution libcall (e.g. uneven padding and
+// dilation), returns kPad and/or kSlice instructions that explicitly apply the
+// padding; otherwise returns the original input operand. When there is both
+// positive padding (including dilation) and negative padding, we insert both
+// kPad and kSlice.
+HloInstruction* MaybePaddedAndSlicedInput(
+ const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums,
+ HloInstruction* input) {
+ HloComputation* computation = input->parent();
+ if (!window_util::HasEvenPadding(conv_window) ||
+ window_util::HasBaseDilation(conv_window)) {
+ // If padding is uneven or has dilation, we insert a kPad instruction that
+ // applies positive padding and dilation.
+ PaddingConfig padding_config =
+ MakeNoPaddingConfig(input->shape().dimensions_size());
+ for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.spatial_dimensions(i);
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ std::max<int64>(0LL, conv_window.dimensions(i).padding_low()));
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ std::max<int64>(0LL, conv_window.dimensions(i).padding_high()));
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ conv_window.dimensions(i).base_dilation() - 1);
+ }
+ PrimitiveType element_type = input->shape().element_type();
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ input = computation->AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(
+ /*operand_shape=*/input->shape(),
+ /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
+ padding_config)
+ .ConsumeValueOrDie(),
+ input, padding, padding_config));
+ }
+
+ if (window_util::HasNegativePadding(conv_window)) {
+ // If the window has negative padding, insert a kSlice that explicitly
+ // applies negative padding.
+ //
+ // For each dimension, initialize the start index to 0 and the limit index
+ // to the size of that dimension.
+ std::vector<int64> start_indices(input->shape().dimensions_size(), 0);
+ std::vector<int64> limit_indices(input->shape().dimensions().begin(),
+ input->shape().dimensions().end());
+ for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.spatial_dimensions(i);
+ // If dimension "dim" has negative padding, increase the start index or
+ // decrement the limit index by the amount of negative padding.
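+      // For example, padding_low = -2 increases the start index of that
+      // dimension by 2, and padding_high = -3 decreases its limit index by 3.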
+ start_indices[dim] +=
+ std::max<int64>(0LL, -conv_window.dimensions(i).padding_low());
+ limit_indices[dim] -=
+ std::max<int64>(0LL, -conv_window.dimensions(i).padding_high());
+ }
+
+ input = computation->AddInstruction(HloInstruction::CreateSlice(
+ ShapeInference::InferSliceShape(input->shape(), start_indices,
+ limit_indices)
+ .ConsumeValueOrDie(),
+ input, start_indices, limit_indices));
+ }
+
+ return input;
+}
+
+// If the padding on the kernel operand of a convolution can't be folded into a
+// cuDNN convolution libcall (e.g. dilation), returns a kPad instruction that
+// explicitly applies the padding; otherwise returns the original kernel
+// operand.
+HloInstruction* MaybePaddedKernel(const Window& conv_window,
+ const ConvolutionDimensionNumbers& conv_dnums,
+ HloInstruction* kernel) {
+ if (!window_util::HasWindowDilation(conv_window)) {
+ return kernel;
+ }
+
+ // Compute the shape and padding config of the pad to be inserted.
+ PaddingConfig padding_config;
+ for (size_t i = 0; i < kernel->shape().dimensions_size(); ++i) {
+ padding_config.add_dimensions();
+ }
+ for (size_t i = 0; i < conv_dnums.spatial_dimensions().size(); ++i) {
+ int64 dim = conv_dnums.spatial_dimensions(i);
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ conv_window.dimensions(i).window_dilation() - 1);
+ }
+
+ HloComputation* computation = kernel->parent();
+ PrimitiveType element_type = kernel->shape().element_type();
+ HloInstruction* padding =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ MakeUnique<Literal>(LiteralUtil::Zero(element_type))));
+ return computation->AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(
+ /*operand_shape=*/kernel->shape(),
+ /*padding_value_shape=*/ShapeUtil::MakeShape(element_type, {}),
+ padding_config)
+ .ConsumeValueOrDie(),
+ kernel, padding, padding_config));
+}
+} // namespace
+
+bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) {
+ if (IsForwardConvolutionCanonical(*conv)) {
+ return false;
+ }
+
+ // Insert slices and/or pads between the convolution and its input and/or
+ // kernel operand.
+ HloInstruction* new_input = MaybePaddedAndSlicedInput(
+ conv->window(), conv->convolution_dimension_numbers(),
+ conv->mutable_operand(0));
+ HloInstruction* new_kernel =
+ MaybePaddedKernel(conv->window(), conv->convolution_dimension_numbers(),
+ conv->mutable_operand(1));
+
+  // Remove the padding from the convolution's window field. This padding is
+  // made explicit by the inserted pads.
+ Window new_conv_window = conv->window();
+ for (size_t i = 0; i < new_conv_window.dimensions_size(); ++i) {
+ WindowDimension* dim = new_conv_window.mutable_dimensions(i);
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ dim->set_base_dilation(1);
+ dim->set_window_dilation(1);
+ }
+ conv->parent()->ReplaceWithNewInstruction(
+ conv, HloInstruction::CreateConvolve(
+ conv->shape(), new_input, new_kernel, new_conv_window,
+ conv->convolution_dimension_numbers()));
+ return true;
+}
+
+namespace {
+void IncreasePaddingLowBy(int64 delta, WindowDimension* window_dim) {
+ window_dim->set_padding_low(window_dim->padding_low() + delta);
+}
+
+void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) {
+ window_dim->set_padding_high(window_dim->padding_high() + delta);
+}
+} // namespace
+
+bool PadInsertion::CanonicalizeBackwardFilterConvolution(
+ HloInstruction* backward_conv) {
+ if (window_util::HasEvenPadding(backward_conv->window())) {
+ return false;
+ }
+
+ // A backward filter convolution with uneven padding can be canonicalized to
+ // one with even padding by padding the activations (input) beforehand. For
+ // example,
+ // BackwardFilterConv(ABCD, xyz, padding_low=1, padding_high=2)
+ // is equivalent to
+ // ABCD0 = Pad(ABCD, padding_high=1)
+  //   BackwardFilterConv(ABCD0, xyz, padding_low=padding_high=1)
+ // We choose the lesser of padding_low and padding_high as the new padding.
+ HloInstruction* transpose = backward_conv->fused_expression_root();
+ HloInstruction* forward_conv = transpose->mutable_operand(0);
+ HloInstruction* input = backward_conv->mutable_operand(0);
+ Window new_forward_conv_window = forward_conv->window();
+ Window new_backward_conv_window = backward_conv->window();
+ // input_padding_config is the config of the kPad to be inserted.
+ PaddingConfig input_padding_config =
+ MakeNoPaddingConfig(ShapeUtil::Rank(input->shape()));
+ ConvolutionDimensionNumbers forward_conv_dnums =
+ forward_conv->convolution_dimension_numbers();
+ ConvolutionDimensionNumbers backward_conv_dnums =
+ backward_conv->convolution_dimension_numbers();
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64 padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64 padding_high = backward_conv->window().dimensions(i).padding_high();
+ if (padding_low < 0 || padding_high < 0) {
+ // TODO(b/32744257): The following canonicalization wouldn't remove
+ // negative padding in a backward convolution, and would therefore cause
+ // cuDNN convolution (which doesn't support negative padding) to fail.
+ return false;
+ }
+ // If the backward convolution has uneven padding on the activations, we
+ // move some padding on the larger end to "internal" padding, so that the
+ // backward convolution produces larger weight gradients which get sliced
+ // later. Therefore, the amount of new padding (low or high) is the minimum
+ // of the amount of old padding low and old padding high.
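+    // For example, with padding_low=1 and padding_high=2 as in the example
+    // above, new_conv_padding = min(1, 2) = 1, so the inserted kPad gets
+    // edge_padding_low = 1 - 1 = 0 and edge_padding_high = 2 - 1 = 1.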
+ int64 new_conv_padding = std::min(padding_low, padding_high);
+ int64 dim = backward_conv_dnums.spatial_dimensions(i);
+ input_padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ padding_low - new_conv_padding);
+ input_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ padding_high - new_conv_padding);
+
+    // Since we move some padding from the backward convolution to the kPad, we
+    // need to reduce the padding amounts of the backward convolution and of
+    // its inner forward convolution accordingly.
+ IncreasePaddingLowBy(-(padding_low - new_conv_padding),
+ new_backward_conv_window.mutable_dimensions(i));
+ IncreasePaddingHighBy(-(padding_high - new_conv_padding),
+ new_backward_conv_window.mutable_dimensions(i));
+ IncreasePaddingLowBy(-(padding_low - new_conv_padding),
+ new_forward_conv_window.mutable_dimensions(i));
+ IncreasePaddingHighBy(-(padding_high - new_conv_padding),
+ new_forward_conv_window.mutable_dimensions(i));
+ }
+
+ // Create a new backward convolution replacing the old one.
+ HloComputation* computation = backward_conv->parent();
+ HloInstruction* output = backward_conv->mutable_operand(1);
+ HloInstruction* padding = computation->AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(
+ LiteralUtil::Zero(input->shape().element_type()))));
+ HloInstruction* padded_input =
+ computation->AddInstruction(HloInstruction::CreatePad(
+ ShapeInference::InferPadShape(input->shape(), padding->shape(),
+ input_padding_config)
+ .ConsumeValueOrDie(),
+ input, padding, input_padding_config));
+
+ HloInstruction* new_forward_conv =
+ computation->AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ padded_input->shape(), output->shape(), new_forward_conv_window,
+ forward_conv_dnums)
+ .ConsumeValueOrDie(),
+ padded_input, output, new_forward_conv_window, forward_conv_dnums));
+
+ HloInstruction* new_transpose =
+ computation->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeInference::InferTransposeShape(new_forward_conv->shape(),
+ transpose->dimensions())
+ .ConsumeValueOrDie(),
+ new_forward_conv, transpose->dimensions()));
+
+ // Fuse the new forward convolution and the new transpose to the new backward
+ // convolution.
+ HloInstruction* new_backward_conv =
+ computation->CreateFusionInstructionForBackwardConvolution(
+ {new_transpose, new_forward_conv},
+ HloInstruction::FusionKind::kConvBackwardFilter,
+ new_backward_conv_window, backward_conv_dnums);
+ computation->ReplaceInstruction(backward_conv, new_backward_conv);
+ return true;
+}
+
+bool PadInsertion::CanonicalizeBackwardInputConvolution(
+ HloInstruction* backward_conv) {
+ if (window_util::HasEvenPadding(backward_conv->window())) {
+ return false;
+ }
+
+ HloInstruction* forward_conv = backward_conv->fused_expression_root();
+ HloInstruction* reverse_filter = forward_conv->mutable_operand(1);
+ Window new_forward_conv_window = forward_conv->window();
+ Window new_backward_conv_window = backward_conv->window();
+ ConvolutionDimensionNumbers forward_conv_dnums =
+ forward_conv->convolution_dimension_numbers();
+ ConvolutionDimensionNumbers backward_conv_dnums =
+ backward_conv->convolution_dimension_numbers();
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64 padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64 padding_high = backward_conv->window().dimensions(i).padding_high();
+ if (padding_low < 0 || padding_high < 0) {
+ // TODO(b/32744257): The following canonicalization wouldn't remove
+ // negative padding in a backward convolution, and would therefore cause
+ // cuDNN convolution (which doesn't support negative padding) to fail.
+ return false;
+ }
+ // If the backward convolution has uneven padding on the activations, we
+ // move some padding on the larger end to "internal" padding, so that the
+ // backward convolution produces larger activations which get sliced later.
+ //
+ // For example, suppose we have a non-canonical HLO
+ // [A] = BackwardInputConvolve([a b], [x y z], padding=(low=2,high=1))
+ // where the amount of padding low is larger, we can canonicalize it to
+ // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1))
+ // [A] = Slice([B A])
+ // For consistency, we need to increase the low padding of the inner
+ // convolution by 1 as well because the input is larger now.
+ if (padding_low > padding_high) {
+ IncreasePaddingLowBy(padding_high - padding_low,
+ new_backward_conv_window.mutable_dimensions(i));
+ IncreasePaddingLowBy(padding_low - padding_high,
+ new_forward_conv_window.mutable_dimensions(i));
+ } else if (padding_low < padding_high) {
+ IncreasePaddingHighBy(padding_low - padding_high,
+ new_backward_conv_window.mutable_dimensions(i));
+ IncreasePaddingHighBy(padding_high - padding_low,
+ new_forward_conv_window.mutable_dimensions(i));
+ }
+ }
+
+ // Create a new backward convolution replacing the old one.
+ HloComputation* computation = backward_conv->parent();
+ HloInstruction* output = backward_conv->mutable_operand(0);
+ HloInstruction* filter = backward_conv->mutable_operand(1);
+ HloInstruction* new_reverse_filter =
+ computation->AddInstruction(HloInstruction::CreateReverse(
+ filter->shape(), filter, reverse_filter->dimensions()));
+ HloInstruction* new_forward_conv =
+ computation->AddInstruction(HloInstruction::CreateConvolve(
+ ShapeInference::InferConvolveShape(
+ output->shape(), new_reverse_filter->shape(),
+ new_forward_conv_window, forward_conv_dnums)
+ .ConsumeValueOrDie(),
+ output, new_reverse_filter, new_forward_conv_window,
+ forward_conv_dnums));
+ HloInstruction* new_backward_conv =
+ computation->CreateFusionInstructionForBackwardConvolution(
+ {new_forward_conv, new_reverse_filter},
+ HloInstruction::FusionKind::kConvBackwardInput,
+ new_backward_conv_window, backward_conv_dnums);
+
+ // Slice the new backward convolution.
+ //
+ // Initialize start_indices and limit_indices as no slicing.
+ std::vector<int64> start_indices(new_backward_conv->shape().dimensions_size(),
+ 0LL);
+ std::vector<int64> limit_indices(
+ new_backward_conv->shape().dimensions().begin(),
+ new_backward_conv->shape().dimensions().end());
+ for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) {
+ int64 padding_low = backward_conv->window().dimensions(i).padding_low();
+ int64 padding_high = backward_conv->window().dimensions(i).padding_high();
+ int64 dim = backward_conv_dnums.spatial_dimensions(i);
+ if (padding_low > padding_high) {
+ // If the amount of low padding (of the old backward convolution) is
+ // larger, we internally pad the low end of the activations and slice
+ // internal padding out here.
+ start_indices[dim] += padding_low - padding_high;
+ } else if (padding_low < padding_high) {
+ // If the amount of high padding is larger, we slice out the internal
+ // padding on the high end.
+ limit_indices[dim] -= padding_high - padding_low;
+ }
+ }
+
+ // Replace the old backward convolution with the slice.
+ CHECK(ShapeUtil::Compatible(
+ ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices,
+ limit_indices)
+ .ConsumeValueOrDie(),
+ backward_conv->shape()));
+ computation->ReplaceWithNewInstruction(
+ backward_conv,
+ HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv,
+ start_indices, limit_indices));
+ return true;
+}
+
+StatusOr<bool> PadInsertion::Run(HloModule* module) {
+ bool changed = false;
+ for (HloInstruction* instruction :
+ module->entry_computation()->MakeInstructionPostOrder()) {
+ if (instruction->opcode() == HloOpcode::kConvolution) {
+ changed |= CanonicalizeForwardConvolution(instruction);
+ } else if (instruction->opcode() == HloOpcode::kFusion) {
+ switch (instruction->fusion_kind()) {
+ case HloInstruction::FusionKind::kConvBackwardFilter:
+ changed |= CanonicalizeBackwardFilterConvolution(instruction);
+ break;
+ case HloInstruction::FusionKind::kConvBackwardInput:
+ changed |= CanonicalizeBackwardInputConvolution(instruction);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.h b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
new file mode 100644
index 0000000000..ec517ed8e6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+namespace gpu {
+
+// An HLO pass that canonicalizes convolution instructions for GPU codegen. It
+// inserts Pad instructions before Convolution instructions with uncanonicalized
+// padding, so that they can be lowered to cuDNN convolution.
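+//
+// Example usage (sketch; assumes an existing HloModule `module` and a
+// TF_ASSIGN_OR_RETURN-style status macro in scope):
+//   PadInsertion pad_insertion;
+//   TF_ASSIGN_OR_RETURN(bool changed, pad_insertion.Run(&module));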
+class PadInsertion : public HloPass {
+ public:
+ PadInsertion() : HloPass("pad insertion") {}
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+  // Returns whether any changes were made to the parent computation.
+ bool CanonicalizeForwardConvolution(HloInstruction* conv);
+ bool CanonicalizeBackwardFilterConvolution(HloInstruction* backward_conv);
+ bool CanonicalizeBackwardInputConvolution(HloInstruction* backward_conv);
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PAD_INSERTION_H_
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
new file mode 100644
index 0000000000..65610b0995
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
+
+#include <memory>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
+#include "external/llvm/include/llvm/IR/Intrinsics.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace gpu {
+
+ParallelLoopEmitter::ParallelLoopEmitter(
+ BodyEmitter body_emitter, const Shape& shape,
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ : LoopEmitter(body_emitter, shape, ir_builder),
+ launch_dimensions_(launch_dimensions) {}
+
+ParallelLoopEmitter::ParallelLoopEmitter(
+ const llvm_ir::ElementGenerator& target_element_generator,
+ const llvm_ir::IrArray& target_array,
+ const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder)
+ : LoopEmitter(target_element_generator, target_array, ir_builder),
+ launch_dimensions_(launch_dimensions) {}
+
+llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock() {
+ // Emit the following code in LLVM IR:
+ // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
+ // if (linear_index < num_elements) {
+ // array_index = LinearIndexToMultidimensionalIndex(shape_, linear_index);
+ // ...
+ // }
+
+ // Per the PTX documentation:
+ // "It is guaranteed that [...] 0 <= %ctaid.x < %nctaid.x"
+ //
+ // %nctaid.x is currently specified as 2147483647.
+ llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, ir_builder_);
+ llvm_ir::AddRangeMetadata(0, launch_dimensions_.block_count(),
+ static_cast<llvm::Instruction*>(block_id));
+ block_id =
+ ir_builder_->CreateZExt(block_id, ir_builder_->getInt64Ty(), "block_id");
+
+ // Per the PTX documentation:
+ // "It is guaranteed that [...] 0 <= %tid.x < %ntid.x"
+ //
+ // %ntid.x is currently specified as 1024.
+ llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, ir_builder_);
+ llvm_ir::AddRangeMetadata(0, launch_dimensions_.threads_per_block(),
+ static_cast<llvm::Instruction*>(thread_id));
+ thread_id = ir_builder_->CreateZExt(thread_id, ir_builder_->getInt64Ty(),
+ "thread_id");
+
+ llvm::Value* linear_index = ir_builder_->CreateAdd(
+ ir_builder_->CreateMul(
+ block_id,
+ ir_builder_->getInt64(launch_dimensions_.threads_per_block()), "",
+ /*HasNUW=*/true, /*HasNSW=*/true),
+ thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true);
+
+ auto if_in_bounds = llvm_ir::EmitIfThenElse(
+ ir_builder_->CreateICmpULT(
+ linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))),
+ "in_bounds", ir_builder_, false);
+
+ // Set exit_bb_ to the exit block of the if structure.
+ exit_bb_ = if_in_bounds.after_block;
+ CHECK_NE(nullptr, exit_bb_);
+
+ // Set IR builder insertion point to the body of the if structure.
+ llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, ir_builder_);
+ return llvm_ir::IrArray::Index(linear_index, shape_, ir_builder_);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
new file mode 100644
index 0000000000..73ca28cd84
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+
+namespace xla {
+namespace gpu {
+
+// Emits a parallel loop for every element in the given array shape. The
+// emitted loop is executed by multiple threads in parallel. Therefore, each
+// thread instance of the loop iterates over part of the array, and the threads
+// collectively iterate over the entire array.
+class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
+ public:
+  // `launch_dimensions` specifies the block and thread counts used to
+  // parallelize the loop. The meanings of the other parameters are the same
+  // as in LoopEmitter.
+ ParallelLoopEmitter(BodyEmitter body_emitter, const Shape& shape,
+ const LaunchDimensions& launch_dimensions,
+ llvm::IRBuilder<>* ir_builder);
+ // Constructs a ParallelLoopEmitter from an element generator that generates
+ // each element of the given target array.
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+ const llvm_ir::IrArray& target_array,
+ const LaunchDimensions& launch_dimensions,
+ llvm::IRBuilder<>* ir_builder);
+ ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
+ ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
+ ~ParallelLoopEmitter() override = default;
+
+ llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock() override;
+
+ private:
+  // The thread and block dimensions used to parallelize the loop.
+ const LaunchDimensions launch_dimensions_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARALLEL_LOOP_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
new file mode 100644
index 0000000000..d0d2deee24
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -0,0 +1,99 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
+
+#include <ostream>
+#include <string>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/bits.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+std::ostream& operator<<(std::ostream& out,
+ const LaunchDimensions& launch_dims) {
+ out << tensorflow::strings::Printf("[block: %lld, thread: %lld]",
+ launch_dims.block_count(),
+ launch_dims.threads_per_block());
+ return out;
+}
+
+// Calculates the launch dimensions used to invoke a kernel that computes an
+// array of the given shape.
+LaunchDimensions CalculateLaunchDimensions(
+ const Shape& shape, const se::DeviceDescription& device_desc,
+ PartitionStrategy partition_strategy) {
+ int64 warp_size = device_desc.threads_per_warp();
+
+ int64 num_elements = ShapeUtil::ElementsIn(shape);
+ if (num_elements <= 1) {
+ return LaunchDimensions();
+ }
+
+ // Calculate the number of threads per block.
+ // Initialize threads_per_block as the threads-per-block limit.
+ int64 threads_per_block = device_desc.threads_per_block_limit();
+ VLOG(2) << "Initial # of threads per block = " << threads_per_block;
+
+ if (partition_strategy == PartitionStrategy::kLatency) {
+    // Limit the thread count to allow the maximum number of registers per
+    // thread.
+    // TODO(b/28560520): We don't have to assume the emitted kernel will use up
+    // all the registers. We could use ptxas to examine the actual number of
+    // registers used, and set the thread count accordingly.
+ int64 threads_per_block_limit_due_to_registers =
+ device_desc.registers_per_core_limit() /
+ device_desc.registers_per_thread_limit();
+ CHECK_NE(0, threads_per_block_limit_due_to_registers);
+ if (threads_per_block_limit_due_to_registers < threads_per_block) {
+ threads_per_block =
+          // Make `threads_per_block` a multiple of the warp size to use the
+          // GPU efficiently.
+ warp_size *
+ std::max(1LL, threads_per_block_limit_due_to_registers / warp_size);
+ VLOG(2) << "Update # of threads per block due to register pressure = "
+ << threads_per_block;
+ }
+ }
+
+ if (num_elements < threads_per_block) {
+ threads_per_block = num_elements;
+ VLOG(2) << "Update # of threads per block to the element count ("
+ << threads_per_block << ") because the latter is smaller.";
+ }
+
+ // Calculate the block count. We copy the strategy used by Eigen:
+ // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h
+ int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+ VLOG(2) << tensorflow::strings::Printf(
+ "Initialized the block count to ceil(# of elements / threads per "
+ "block) = ceil(%lld/%lld) = %lld",
+ num_elements, threads_per_block, block_count);
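+  // For example, 1,000,000 elements with 1,024 threads per block gives
+  // block_count = ceil(1000000 / 1024) = 977.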
+
+ return LaunchDimensions(block_count, threads_per_block);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
new file mode 100644
index 0000000000..8ac4c59966
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Algorithms and data structures for partition assignment.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_
+
+#include <iosfwd>
+#include <map>
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+enum class PartitionStrategy {
+ // Optimized for latency by allowing maximum number of registers per thread.
+ kLatency,
+  // Optimized for throughput. This may limit registers per thread and cause
+ // longer latency.
+ kThroughput
+};
+
+// Encapsulates the launch dimensions of a kernel, e.g., the block count and the
+// number of threads per block.
+class LaunchDimensions {
+ public:
+  // The default constructor creates launch dimensions that indicate
+  // single-threaded execution.
+ LaunchDimensions() : block_count_(1), threads_per_block_(1) {}
+
+ LaunchDimensions(int64 block_count, int64 threads_per_block)
+ : block_count_(block_count), threads_per_block_(threads_per_block) {}
+
+ bool IsSinglethreaded() const {
+ return block_count_ == 1 && threads_per_block_ == 1;
+ }
+
+ int64 block_count() const { return block_count_; }
+ int64 threads_per_block() const { return threads_per_block_; }
+
+ private:
+ int64 block_count_;
+ int64 threads_per_block_;
+};
+
+std::ostream& operator<<(std::ostream& out,
+ const LaunchDimensions& launch_dims);
+
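+// Calculates the launch dimensions for a kernel that computes an array of the
+// given shape on the given device.
+//
+// Example usage (sketch; assumes a Shape `shape` and a
+// perftools::gputools::StreamExecutor* `stream_exec` are in scope):
+//   LaunchDimensions dims = CalculateLaunchDimensions(
+//       shape, stream_exec->GetDeviceDescription());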
+LaunchDimensions CalculateLaunchDimensions(
+ const Shape& shape,
+ const perftools::gputools::DeviceDescription& device_desc,
+ PartitionStrategy partition_strategy = PartitionStrategy::kLatency);
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_PARTITION_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
new file mode 100644
index 0000000000..d8a43091d4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.cc
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+SequentialThunk::SequentialThunk(std::vector<std::unique_ptr<Thunk>>&& thunks,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kSequential, hlo), thunks_(std::move(thunks)) {}
+
+tensorflow::Status SequentialThunk::Initialize(
+ const GpuExecutable& executable) {
+ for (auto& thunk : thunks_) {
+ TF_RETURN_IF_ERROR(thunk->Initialize(executable));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status SequentialThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+ for (const auto& thunk : thunks_) {
+ TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream));
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/sequential_thunk.h b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
new file mode 100644
index 0000000000..32c5b748ab
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/sequential_thunk.h
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A thunk that wraps a list of sub-thunks. Executing this thunk executes all
+// the sub-thunks sequentially. This is useful to implement instructions that
+// require multiple kernel launches or library calls.
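+//
+// Example construction (sketch; assumes `thunk_a` and `thunk_b` are existing
+// std::unique_ptr<Thunk> values and `hlo` is the producing HloInstruction*):
+//   std::vector<std::unique_ptr<Thunk>> thunks;
+//   thunks.push_back(std::move(thunk_a));
+//   thunks.push_back(std::move(thunk_b));
+//   SequentialThunk sequential(std::move(thunks), hlo);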
+class SequentialThunk : public Thunk {
+ public:
+ SequentialThunk(std::vector<std::unique_ptr<Thunk>>&& thunks,
+ const HloInstruction* hlo);
+ SequentialThunk(const SequentialThunk&) = delete;
+ SequentialThunk& operator=(const SequentialThunk&) = delete;
+
+ const std::vector<std::unique_ptr<Thunk>>& thunks() const { return thunks_; }
+
+ tensorflow::Status Initialize(const GpuExecutable& executable) override;
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ // The list of sub-thunks.
+ std::vector<std::unique_ptr<Thunk>> thunks_;
+};
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_SEQUENTIAL_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
new file mode 100644
index 0000000000..5065e7aedd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+
+#include "tensorflow/compiler/xla/legacy_flags/stream_assignment_flags.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+
+namespace xla {
+namespace gpu {
+
+bool StreamAssignment::HasStreamAssigned(const HloInstruction& hlo) const {
+ return hlo_to_stream_number_.count(&hlo);
+}
+
+int StreamAssignment::StreamNumberForHlo(const HloInstruction& hlo) const {
+ return FindOrDie(hlo_to_stream_number_, &hlo);
+}
+
+void StreamAssignment::AssignStreamToHlo(const HloInstruction* hlo,
+ int stream_no) {
+ CHECK_GE(stream_no, 0);
+ if (stream_no >= stream_count_) {
+ stream_count_ = stream_no + 1;
+ }
+ InsertOrDie(&hlo_to_stream_number_, hlo, stream_no);
+ VLOG(2) << "Assign stream #" << stream_no << " to " << hlo->ToString();
+}
+
+namespace {
+
+// Returns whether the two HLOs can run concurrently, i.e., neither is a
+// transitive consumer of the other.
+bool CanRunConcurrently(
+ const HloInstruction& a, const HloInstruction& b,
+ const HloComputation::ReachabilityMap& transitive_operands) {
+ return !transitive_operands.IsConnected(&a, &b);
+}
+
+// Returns which existing stream to assign to `hlo`, or -1 if a stream is not
+// needed. `stream_assignment` is the existing stream assignment for all
+// instructions topologically before `hlo`. `seen_gemms` contains all GEMMs that
+// are topologically before `hlo`.
+int ComputeStreamToAssign(
+ const HloInstruction& hlo, const StreamAssignment& stream_assignment,
+ const HloComputation::ReachabilityMap& transitive_operands,
+ const std::vector<const HloInstruction*>& seen_gemms) {
+ if (hlo.opcode() == HloOpcode::kParameter ||
+ hlo.opcode() == HloOpcode::kConstant) {
+ // kParameter and kConstant do not need a thunk.
+ return -1;
+ }
+
+ legacy_flags::StreamAssignmentFlags* flags =
+ legacy_flags::GetStreamAssignmentFlags();
+ if (flags->xla_gpu_disable_multi_streaming) {
+ return 0;
+ }
+
+ if (!ImplementedAsGemm(hlo)) {
+ // If `hlo` is not implemented as a GEMM, keep it close to its operands to
+ // avoid excessive synchronization.
+ int stream_no = -1;
+ for (const auto* operand : hlo.operands()) {
+ if (stream_assignment.HasStreamAssigned(*operand)) {
+ stream_no =
+ std::max(stream_no, stream_assignment.StreamNumberForHlo(*operand));
+ }
+ }
+ if (stream_no == -1) {
+ stream_no = 0;
+ }
+ return stream_no;
+ }
+
+  // Assign different streams to concurrent GEMMs. The code below uses a
+  // greedy approach: first, collect in `forbidden_stream_numbers` the streams
+  // already assigned to GEMMs that can run concurrently with `hlo`; then,
+  // assign `hlo` a stream that is not in that set.
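+  // For example, if streams 0 and 1 are already assigned to GEMMs that can run
+  // concurrently with `hlo`, `hlo` gets stream 2 (creating a new stream if
+  // only streams 0 and 1 exist so far).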
+ std::set<int> forbidden_stream_numbers;
+ for (const auto* seen_gemm : seen_gemms) {
+ int stream_no = stream_assignment.StreamNumberForHlo(*seen_gemm);
+ if (!forbidden_stream_numbers.count(stream_no) &&
+ CanRunConcurrently(*seen_gemm, hlo, transitive_operands)) {
+ forbidden_stream_numbers.insert(stream_no);
+ }
+ }
+
+ for (int stream_no = 0; stream_no < stream_assignment.StreamCount();
+ ++stream_no) {
+ if (!forbidden_stream_numbers.count(stream_no)) {
+ return stream_no;
+ }
+ }
+ return stream_assignment.StreamCount();
+}
+
+} // namespace
+
+std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module) {
+ auto stream_assignment = MakeUnique<StreamAssignment>();
+ const HloComputation& computation = *module.entry_computation();
+ std::unique_ptr<HloComputation::ReachabilityMap> transitive_operands =
+ computation.ComputeTransitiveOperands();
+ std::vector<const HloInstruction*> seen_gemms;
+ for (const auto* hlo : computation.MakeInstructionPostOrder()) {
+ int stream_no = ComputeStreamToAssign(*hlo, *stream_assignment,
+ *transitive_operands, seen_gemms);
+ if (stream_no != -1) {
+ stream_assignment->AssignStreamToHlo(hlo, stream_no);
+ }
+ if (ImplementedAsGemm(*hlo)) {
+ seen_gemms.push_back(hlo);
+ }
+ }
+ return stream_assignment;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
new file mode 100644
index 0000000000..c2df83aaa4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace xla {
+namespace gpu {
+
+// This class encapsulates the assignment of GPU streams to each HloInstruction.
+class StreamAssignment {
+ public:
+ int StreamCount() const { return stream_count_; }
+ int StreamNumberForHlo(const HloInstruction& hlo) const;
+ bool HasStreamAssigned(const HloInstruction& hlo) const;
+ // `hlo` needs to outlive this StreamAssignment object.
+ void AssignStreamToHlo(const HloInstruction* hlo, int stream_no);
+
+ private:
+ int stream_count_ = 1; // At least the main stream.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> hlo_to_stream_number_;
+};
+
+// Assigns GPU streams to instructions in `module`.
+std::unique_ptr<StreamAssignment> AssignStreams(const HloModule& module);
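+
+// Example usage (a minimal sketch; `module` is assumed to be an already-built
+// HloModule and `dot` an instruction in its entry computation):
+//
+//   std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+//   if (assignment->HasStreamAssigned(*dot)) {
+//     int stream_no = assignment->StreamNumberForHlo(*dot);
+//     CHECK_LT(stream_no, assignment->StreamCount());
+//   }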
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
new file mode 100644
index 0000000000..28d47d2b0f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment_test.cc
@@ -0,0 +1,132 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+namespace gpu {
+
+class StreamAssignmentTest : public HloTestBase {
+ protected:
+ // Pre-canned shapes.
+ Shape f32_2x2_ = ShapeUtil::MakeShape(F32, {2, 2});
+};
+
+TEST_F(StreamAssignmentTest, SequentialMatMul) {
+ HloComputation::Builder builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+ HloInstruction* z = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/2, f32_2x2_, /*name=*/"z"));
+ HloInstruction* dot1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction* dot2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, dot1, z));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(dot2));
+
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ EXPECT_EQ(assignment->StreamNumberForHlo(*dot1),
+ assignment->StreamNumberForHlo(*dot2));
+}
+
+TEST_F(StreamAssignmentTest, ConcurrentMatMul) {
+ HloComputation::Builder builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, f32_2x2_, /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, f32_2x2_, /*name=*/"y"));
+ HloInstruction* dot1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, x, y));
+ HloInstruction* dot2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, y, x));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kAdd, dot1, dot2));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(add));
+
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ EXPECT_NE(assignment->StreamNumberForHlo(*dot1),
+ assignment->StreamNumberForHlo(*dot2));
+}
+
+TEST_F(StreamAssignmentTest, LatticeMatMul) {
+ // d00 -- layer 0
+ // / \
+ // d10 d11 -- layer 1
+ // / \ / \
+ // d20 d21 d22 -- layer 2
+ // \ / \ /
+ // d30 d31 -- layer 3
+ // \ /
+ // d40 -- layer 4
+ HloComputation::Builder builder("entry_computation");
+ std::vector<HloInstruction*> params;
+ for (int i = 0; i < 6; ++i) {
+ params.push_back(builder.AddInstruction(HloInstruction::CreateParameter(
+ i, f32_2x2_, /*name=*/tensorflow::strings::Printf("param%d", i))));
+ }
+ HloInstruction* d00 = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32_2x2_, HloOpcode::kDot, params[2], params[3]));
+ HloInstruction* d10 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[1], d00));
+ HloInstruction* d11 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d00, params[4]));
+ HloInstruction* d20 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, params[0], d10));
+ HloInstruction* d21 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d10, d11));
+ HloInstruction* d22 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d11, params[5]));
+ HloInstruction* d30 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d20, d21));
+ HloInstruction* d31 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d21, d22));
+ HloInstruction* d40 = builder.AddInstruction(
+ HloInstruction::CreateBinary(f32_2x2_, HloOpcode::kDot, d30, d31));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build(d40));
+
+ std::unique_ptr<StreamAssignment> assignment = AssignStreams(module);
+ // The two dots on layer 1 are concurrent.
+ EXPECT_NE(assignment->StreamNumberForHlo(*d10),
+ assignment->StreamNumberForHlo(*d11));
+ // The three dots on layer 2 are concurrent.
+ EXPECT_NE(assignment->StreamNumberForHlo(*d20),
+ assignment->StreamNumberForHlo(*d21));
+ EXPECT_NE(assignment->StreamNumberForHlo(*d20),
+ assignment->StreamNumberForHlo(*d22));
+ EXPECT_NE(assignment->StreamNumberForHlo(*d21),
+ assignment->StreamNumberForHlo(*d22));
+ // The two dots on layer 3 are concurrent.
+ EXPECT_NE(assignment->StreamNumberForHlo(*d30),
+ assignment->StreamNumberForHlo(*d31));
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc
new file mode 100644
index 0000000000..3cf5dd021a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h"
+
+#include "tensorflow/compiler/xla/map_util.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+int64 RoundUpToAlignment(int64 value) {
+ // Any address of a variable residing in global memory or returned by one of
+ // the memory allocation routines from the driver or runtime API is always
+ // aligned to at least 256 bytes.
+ // (http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses)
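+  // For example, RoundUpToAlignment(1) == 256, RoundUpToAlignment(256) == 256,
+  // and RoundUpToAlignment(257) == 512 ((257 + 255) / 256 * 256 == 512 with
+  // integer division).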
+ static constexpr int64 kCudaMallocAlignment = 256;
+ return (value + kCudaMallocAlignment - 1) / kCudaMallocAlignment *
+ kCudaMallocAlignment;
+}
+} // namespace
+
+TempBufferOffsets::TempBufferOffsets(
+ const BufferAssignment& buffer_assignment) {
+ total_size_of_temp_buffers_ = 0;
+ for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) {
+ const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
+ if (allocation.IsPreallocatedTempBuffer()) {
+ InsertOrDie(&buffer_index_to_offset_, i, total_size_of_temp_buffers_);
+ total_size_of_temp_buffers_ += RoundUpToAlignment(allocation.size());
+ }
+ }
+}
+
+int64 TempBufferOffsets::GetOffset(BufferAllocation::Index index) const {
+ return FindOrDie(buffer_index_to_offset_, index);
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h
new file mode 100644
index 0000000000..05aca99bf3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/temp_buffer_offsets.h
@@ -0,0 +1,47 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_
+
+#include <map>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace gpu {
+
+// GpuExecutable merges all temporary buffers into one memory block. This class
+// stores the offset of each temporary buffer in that memory block.
+class TempBufferOffsets {
+ public:
+ explicit TempBufferOffsets(const BufferAssignment& buffer_assignment);
+
+ int64 GetOffset(BufferAllocation::Index index) const;
+ int64 TotalSizeInBytes() const { return total_size_of_temp_buffers_; }
+
+ private:
+ std::map<BufferAllocation::Index, int64> buffer_index_to_offset_;
+
+ // The total size of all temporary buffers. This includes paddings that are
+ // necessary for alignment.
+ int64 total_size_of_temp_buffers_;
+};
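+
+// Example usage (a minimal sketch; `buffer_assignment` is assumed to be the
+// BufferAssignment computed for the module, and index 3 is assumed to refer to
+// a preallocated temporary buffer):
+//
+//   TempBufferOffsets offsets(buffer_assignment);
+//   int64 block_size = offsets.TotalSizeInBytes();  // size of merged block
+//   int64 offset = offsets.GetOffset(/*index=*/3);  // 256-byte aligned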
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TEMP_BUFFER_OFFSETS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
new file mode 100644
index 0000000000..3ced348400
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+class GpuExecutable;
+
+// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
+// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
+//
+// Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
+// to initialize and execute the invocation respectively. Its subclasses are
+// supposed to override these interfaces to launch a generated kernel or call an
+// external library function (such as operations in cuBLAS).
+//
+// This is thread-compatible.
+class Thunk {
+ public:
+ enum class Kind {
+ kConvolution,
+ kCopy,
+ kGemm,
+ kKernel,
+ kSequential,
+ kTuple,
+ kWhile,
+ };
+
+ // The hlo_instruction argument is meant to be the instruction this thunk was
+ // generated from, but Thunk never uses this argument other than to save it
+ // to Thunk::hlo_instruction, so it can be null.
+ explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
+ : kind_(kind), hlo_instruction_(hlo_instruction) {}
+ virtual ~Thunk() {}
+ Thunk(const Thunk&) = delete;
+ Thunk& operator=(const Thunk&) = delete;
+
+ Kind kind() const { return kind_; }
+ const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
+
+ // Prepares for executing the thunk. This method is called only once over
+ // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a
+ // kernel, which is the same in every execution.
+ virtual tensorflow::Status Initialize(const GpuExecutable& executable) {
+ return tensorflow::Status::OK();
+ }
+
+ // Execute the kernel for the thunk on the given stream. This method must be
+ // called after Initialize and can be called multiple times over Thunk's
+ // lifetime. Stream argument must be non-null.
+ virtual tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) = 0;
+
+ private:
+ Kind kind_;
+ const HloInstruction* hlo_instruction_;
+};
+
+// A sequence of thunks.
+using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
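+
+// A minimal subclass sketch (hypothetical; real thunks such as KernelThunk or
+// GemmThunk live in their own headers and launch actual work):
+//
+//   class NoOpThunk : public Thunk {
+//    public:
+//     explicit NoOpThunk(const HloInstruction* hlo)
+//         : Thunk(Kind::kKernel, hlo) {}
+//     tensorflow::Status ExecuteOnStream(
+//         const BufferAllocations& buffer_allocations,
+//         perftools::gputools::Stream* stream) override {
+//       return tensorflow::Status::OK();  // Nothing to launch.
+//     }
+//   };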
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
new file mode 100644
index 0000000000..8addcd87ea
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc
@@ -0,0 +1,163 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/thunk_schedule.h"
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+void ThunkSchedule::AddDependenciesOnTransitiveOperands(
+ const Thunk& thunk, const HloInstruction& operand,
+ const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk) {
+ if (hlo_to_thunk.count(&operand)) {
+    // If `operand` is mapped to a thunk and is assigned to a different stream,
+    // add `operand`'s thunk to `thunk`'s dependency list. As an optimization,
+    // we skip `operand`'s operands because `operand` already depends on them.
+ if (stream_assignment_->StreamNumberForHlo(operand) !=
+ stream_assignment_->StreamNumberForHlo(*thunk.hlo_instruction())) {
+ depends_on_[&thunk].push_back(FindOrDie(hlo_to_thunk, &operand));
+ }
+ } else {
+ // If `operand` doesn't need a thunk (e.g. bitcast), continue with its
+ // operands.
+ for (const auto* operand_of_operand : operand.operands()) {
+ AddDependenciesOnTransitiveOperands(thunk, *operand_of_operand,
+ hlo_to_thunk);
+ }
+ }
+}
+
+ThunkSchedule::ThunkSchedule(
+ std::unique_ptr<ThunkSequence> thunks,
+ std::unique_ptr<StreamAssignment> stream_assignment,
+ const std::vector<const HloInstruction*>& hlo_total_order)
+ : thunks_(std::move(thunks)),
+ stream_assignment_(std::move(stream_assignment)) {
+ std::unordered_map<const HloInstruction*, Thunk*> hlo_to_thunk;
+ for (const auto& thunk : *thunks_) {
+ InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get());
+ }
+
+ for (const HloInstruction* hlo : hlo_total_order) {
+ if (hlo_to_thunk.count(hlo)) {
+ thunk_total_order_.push_back(FindOrDie(hlo_to_thunk, hlo));
+ }
+ }
+
+ for (const Thunk* thunk : thunk_total_order_) {
+ const auto* dst = thunk->hlo_instruction();
+ CHECK(stream_assignment_->HasStreamAssigned(*dst));
+ for (const auto* src : dst->operands()) {
+ AddDependenciesOnTransitiveOperands(*thunk, *src, hlo_to_thunk);
+ }
+ }
+
+ RemoveRedundantDependencyEdges();
+
+ // Compute `depended_by_`, the inverse of `depends_on_`.
+ for (const auto& dependency : depends_on_) {
+ for (const auto* depended : dependency.second) {
+ depended_by_.insert(depended);
+ }
+ }
+}
+
+void ThunkSchedule::RemoveRedundantDependencyEdges() {
+ std::unordered_map<const Thunk*, int> thunk_to_total_order;
+ for (auto i = 0; i < thunk_total_order_.size(); ++i) {
+ InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
+ }
+
+ int stream_count = stream_assignment_->StreamCount();
+ // S1 S2
+ //
+ // T1<----+
+ // |
+ // T3<--+ |
+ // | | depends on
+ // T4 |
+ // |
+ // T2-+
+ //
+ // Suppose thunk T1 and T3 are scheduled on stream S1, and T2 and T4 are on
+ // stream S2. If T2 depends on T1 and T4 depends on T3, and
+ // order(T1)<order(T3)<order(T4)<order(T2), the dependency of T2 on T1 is
+ // redundant.
+ //
+  // To efficiently detect such redundancy, we leverage the array
+  // `last_dependency`. last_dependency[S1][S2] holds the order number of the
+  // latest thunk (i.e. the one with the maximum order number) on stream S2
+  // that thunks on S1 depend on. Therefore, if a later S1 thunk depends on an
+  // S2 thunk ordered <= last_dependency[S1][S2], that dependency edge is
+  // redundant.
+ Array2D<int> last_dependency(stream_count, stream_count, -1);
+ for (const Thunk* dst : thunk_total_order_) {
+ if (!depends_on_.count(dst)) {
+ continue;
+ }
+
+ int dst_stream =
+ stream_assignment_->StreamNumberForHlo(*dst->hlo_instruction());
+ std::list<const Thunk*>& sources = FindOrDie(depends_on_, dst);
+ for (auto iter = sources.begin(); iter != sources.end();) {
+ const Thunk* src = *iter;
+ // `dst` depends on `src`.
+ int src_stream =
+ stream_assignment_->StreamNumberForHlo(*src->hlo_instruction());
+ int src_order = FindOrDie(thunk_to_total_order, src);
+ if (src_order <= last_dependency(dst_stream, src_stream)) {
+ iter = sources.erase(iter);
+ } else {
+ last_dependency(dst_stream, src_stream) = src_order;
+ ++iter;
+ }
+ }
+ if (sources.empty()) {
+ depends_on_.erase(dst);
+ }
+ }
+}
+
+const std::list<const Thunk*>& ThunkSchedule::DependsOn(
+ const Thunk* thunk) const {
+ if (depends_on_.count(thunk)) {
+ return FindOrDie(depends_on_, thunk);
+ } else {
+ return empty_thunk_list_;
+ }
+}
+
+string ThunkSchedule::ToString() const {
+ string result = "Total order:\n";
+ for (Thunk* thunk : thunk_total_order_) {
+ tensorflow::strings::StrAppend(&result, "\t",
+ thunk->hlo_instruction()->ToString(), "\n");
+ }
+ tensorflow::strings::StrAppend(&result, "Dependencies:\n");
+ for (const auto& entry : depends_on_) {
+ const Thunk* dependent = entry.first;
+ for (const Thunk* dependency : entry.second) {
+ tensorflow::strings::StrAppend(
+ &result, "\t", dependent->hlo_instruction()->name(), " depends on ",
+ dependency->hlo_instruction()->name(), "\n");
+ }
+ }
+ return result;
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h
new file mode 100644
index 0000000000..d3352994f8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
+
+#include <list>
+#include <memory>
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace gpu {
+
+// Encapsulates in which order and on which streams the thunks are executed. A
+// schedule contains
+//
+// 1. A stream assignment indicating which stream each thunk is executed on.
+//
+// 2. A total order of all thunks. If A is ordered before B and they are
+// assigned to the same stream, then A completes before B starts. If A is
+// ordered before B and they are on different streams, their actual execution
+// order is not determined.
+//
+// 3. A set of dependency edges. If A and B are scheduled on different streams
+// and A has to complete before B starts (e.g. A produces an input of B), then B
+// "depends" on A.
+class ThunkSchedule {
+ public:
+ ThunkSchedule(std::unique_ptr<ThunkSequence> thunks,
+ std::unique_ptr<StreamAssignment> stream_assignment,
+ const std::vector<const HloInstruction*>& hlo_total_order);
+
+ // Returns the total order of executing all the thunks.
+ const std::vector<Thunk*>& TotalOrder() const { return thunk_total_order_; }
+
+ // Thunks that `thunk` depends on.
+ const std::list<const Thunk*>& DependsOn(const Thunk* thunk) const;
+  // Whether `thunk` is depended on by another thunk.
+ bool Depended(const Thunk* thunk) const { return depended_by_.count(thunk); }
+
+ // Delegates to StreamAssignment.
+ int StreamCount() const { return stream_assignment_->StreamCount(); }
+ int StreamNumberForHlo(const HloInstruction& hlo) const {
+ return stream_assignment_->StreamNumberForHlo(hlo);
+ }
+
+ string ToString() const;
+
+ private:
+ void RemoveRedundantDependencyEdges();
+
+ // Adds `operand` and its transitive operands to the dependency list of
+ // `thunk`.
+ //
+ // Precondition: `operand` is a non-trivial (i.e. excluding
+ // thunk.hlo_instruction() itself) transitive operand of
+ // thunk.hlo_instruction().
+ void AddDependenciesOnTransitiveOperands(
+ const Thunk& thunk, const HloInstruction& operand,
+ const std::unordered_map<const HloInstruction*, Thunk*>& hlo_to_thunk);
+
+ std::unique_ptr<ThunkSequence> thunks_;
+ std::vector<Thunk*> thunk_total_order_;
+
+ std::unordered_map<const Thunk*, std::list<const Thunk*>> depends_on_;
+ std::set<const Thunk*> depended_by_;
+ std::list<const Thunk*> empty_thunk_list_;
+
+ std::unique_ptr<StreamAssignment> stream_assignment_;
+};
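+
+// Example usage (a minimal sketch; `thunks`, `assignment`, and `order` are
+// assumed to come from the IR emitter, AssignStreams, and the HLO scheduler
+// respectively):
+//
+//   ThunkSchedule schedule(std::move(thunks), std::move(assignment), order);
+//   for (Thunk* thunk : schedule.TotalOrder()) {
+//     for (const Thunk* dep : schedule.DependsOn(thunk)) {
+//       // Make `thunk`'s stream wait for `dep` to finish before launching.
+//     }
+//   }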
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_SCHEDULE_H_
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
new file mode 100644
index 0000000000..323775b3e8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
+
+#include "tensorflow/compiler/xla/util.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace gpu {
+
+tensorflow::Status TupleThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations, se::Stream* stream) {
+ std::vector<void*> tuple_element_buffer_addresses;
+ for (BufferAllocation::Index tuple_element_buffer : tuple_element_buffers_) {
+ tuple_element_buffer_addresses.push_back(
+ buffer_allocations.GetDeviceAddress(tuple_element_buffer).opaque());
+ }
+ se::DeviceMemory<void*> dest_buffer_address(
+ buffer_allocations.GetDeviceAddress(dest_buffer_));
+
+ auto host_size = tuple_element_buffer_addresses.size() * sizeof(void*);
+ if (!stream
+ ->ThenMemcpy(&dest_buffer_address,
+ tuple_element_buffer_addresses.data(), host_size)
+ .ok()) {
+ return InternalError(
+ "Unable to launch MemcpyH2D from %p to %p with size %lu",
+ tuple_element_buffer_addresses.data(), dest_buffer_address.opaque(),
+ sizeof(void*) * tuple_element_buffer_addresses.size());
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tuple_thunk.h b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
new file mode 100644
index 0000000000..ca0404286f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tuple_thunk.h
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// A thunk that copies the addresses of tuple elements to the buffer of the
+// tuple. This avoids emitting kernels that may suffer from the parameter space
+// issue (b/31336476).
+class TupleThunk : public Thunk {
+ public:
+ TupleThunk(tensorflow::gtl::ArraySlice<BufferAllocation::Index>
+ tuple_element_buffers,
+ BufferAllocation::Index dest_buffer,
+ const HloInstruction* hlo_instruction)
+ : Thunk(Kind::kTuple, hlo_instruction),
+ tuple_element_buffers_(tuple_element_buffers.begin(),
+ tuple_element_buffers.end()),
+ dest_buffer_(dest_buffer) {}
+
+ TupleThunk(const TupleThunk&) = delete;
+ TupleThunk& operator=(const TupleThunk&) = delete;
+
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ std::vector<BufferAllocation::Index> tuple_element_buffers_;
+ const BufferAllocation::Index dest_buffer_;
+};
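+
+// For example, for a three-element tuple, ExecuteOnStream writes an array of
+// the three elements' device addresses into `dest_buffer_` with a single
+// host-to-device memcpy of 3 * sizeof(void*) bytes.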
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_TUPLE_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
new file mode 100644
index 0000000000..36883e4920
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+WhileThunk::WhileThunk(BufferAllocation::Index condition_result_buffer_index,
+ std::unique_ptr<ThunkSequence> condition_thunk_sequence,
+ std::unique_ptr<ThunkSequence> body_thunk_sequence,
+ const HloInstruction* hlo)
+ : Thunk(Kind::kWhile, hlo),
+ condition_result_buffer_index_(condition_result_buffer_index),
+ condition_thunk_sequence_(MakeUnique<SequentialThunk>(
+ std::move(*condition_thunk_sequence), hlo)),
+ body_thunk_sequence_(
+ MakeUnique<SequentialThunk>(std::move(*body_thunk_sequence), hlo)) {}
+
+tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) {
+ TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable));
+ TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status WhileThunk::ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) {
+
+ perftools::gputools::DeviceMemoryBase condition_result_data =
+ buffer_allocations.GetDeviceAddress(condition_result_buffer_index_);
+
+ while (true) {
+ // Invoke thunk sequence for while 'condition' computation.
+ TF_RETURN_IF_ERROR(
+ condition_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+
+ // Copy the result of condition computation and break the loop if 'false'.
+ bool condition_result;
+ stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool));
+ if (!stream->BlockHostUntilDone()) {
+ return InternalError(
+ "Failed to complete all kernels launched on stream %p", stream);
+ }
+
+ if (!condition_result) {
+ break;
+ }
+
+ // Invoke thunk sequence for while 'body' computation.
+ TF_RETURN_IF_ERROR(
+ body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream));
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h
new file mode 100644
index 0000000000..1658cdaf87
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+namespace gpu {
+
+// WhileThunk implements the while instruction on GPU by invoking a thunk
+// sequence for the while 'condition' computation, and (conditionally) another
+// thunk sequence for the while 'body' computation. WhileThunk assumes that
+// buffers for the following set of while-related instructions share the same
+// allocation:
+// init, condition.parameter, body.parameter, body.root, while.result
+// WhileThunk synchronizes the stream to test the result of the 'condition'
+// computation.
+class WhileThunk : public Thunk {
+ public:
+ // Constructs a WhileThunk to compute while instruction 'hlo'.
+ WhileThunk(BufferAllocation::Index condition_result_buffer_index,
+ std::unique_ptr<ThunkSequence> condition_thunk_sequence,
+ std::unique_ptr<ThunkSequence> body_thunk_sequence,
+ const HloInstruction* hlo);
+ WhileThunk(const WhileThunk&) = delete;
+ WhileThunk& operator=(const WhileThunk&) = delete;
+
+ tensorflow::Status Initialize(const GpuExecutable& executable) override;
+ tensorflow::Status ExecuteOnStream(
+ const BufferAllocations& buffer_allocations,
+ perftools::gputools::Stream* stream) override;
+
+ private:
+ BufferAllocation::Index condition_result_buffer_index_;
+ std::unique_ptr<SequentialThunk> condition_thunk_sequence_;
+ std::unique_ptr<SequentialThunk> body_thunk_sequence_;
+};
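+
+// Execution note (see while_thunk.cc): each iteration runs the condition thunk
+// sequence, copies the PRED result back to the host, blocks until the stream
+// is done, and then either runs the body thunk sequence or exits. The
+// per-iteration host synchronization is inherent to this scheme.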
+
+} // namespace gpu
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_THUNK_H_
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
new file mode 100644
index 0000000000..7ebc929ced
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc
@@ -0,0 +1,532 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+// MatcherBase is a base class that provides common functionality for
+// sub-classes which match specific target sub-computations (i.e. loop
+// induction variable initialization, comparison and update).
+// TODO(b/33483676) Use an expression tree to specify computations to pattern
+// match for while transformations.
+class MatcherBase {
+ public:
+ enum State {
+ ADD,
+ PRED,
+ CONST,
+ COPY,
+ GTE0,
+ GTE1,
+ PARAM,
+ TUPLE0,
+ TUPLE1,
+ WHILE
+ };
+
+ // Initializes MatcherBase with 'computation' and initial state 'state'.
+ explicit MatcherBase(const HloComputation* computation, State state)
+ : computation_(computation), state_(state) {}
+
+ // Initializes MatcherBase with 'computation', initial state 'state', and
+ // value for 'tuple_index'.
+ MatcherBase(const HloComputation* computation, State state,
+ const int64 tuple_index)
+ : computation_(computation), state_(state), tuple_index_(tuple_index) {}
+ virtual ~MatcherBase() {}
+
+ // Overridden by sub-classes to match specific target sub-computations.
+ // Returns OK if target sub-computation was matched, error status otherwise.
+ virtual tensorflow::Status Run() = 0;
+
+ // Matches a Constant instruction of integral type, parses its value, and
+ // stores the value in 'const_value_'.
+ // Returns OK on success, error status otherwise.
+ tensorflow::Status MatchConst() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kConstant) {
+ return InvalidArgument("Expected constant instruction.");
+ }
+ if (!IsSupportedIntType(instruction->shape())) {
+ return InvalidArgument("Expected constant of integral type.");
+ }
+ const Literal& literal = instruction->literal();
+ PrimitiveType type = literal.shape().element_type();
+ if (type == S32) {
+ const_value_ =
+ static_cast<int64>(LiteralUtil::GetFirstElement<int32>(literal));
+ } else if (type == S64) {
+ const_value_ = LiteralUtil::GetFirstElement<int64>(literal);
+ } else {
+ return InvalidArgument("Must use S32 or S64 integral types.");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ // Matches a Copy instruction.
+ // Pushes its operand on the stack for subsequent processing.
+ // Returns OK on success, error status otherwise.
+ tensorflow::Status MatchCopy() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kCopy) {
+ return InvalidArgument("Expectecd Copy.");
+ }
+ stack_.push_back(instruction->operand(0));
+ return tensorflow::Status::OK();
+ }
+
+ // Matches a GetTupleElement instruction and either parses its 'tuple_index'
+ // parameter (if not initialized) or compares its 'tuple_index' with the
+ // previously initialized value.
+ // Pushes its operand on the stack for subsequent processing.
+ // Returns OK on success, error status otherwise.
+ tensorflow::Status MatchGetTupleElement() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kGetTupleElement) {
+ return InvalidArgument("Expected GetTupleElement instruction.");
+ }
+ if (!IsSupportedIntType(instruction->shape())) {
+ return InvalidArgument("GetTupleElement instruction be integral type.");
+ }
+ if (tuple_index_ == -1) {
+ tuple_index_ = instruction->tuple_index();
+ } else if (tuple_index_ != instruction->tuple_index()) {
+ return InvalidArgument("Invalid tuple index");
+ }
+ stack_.push_back(instruction->operand(0));
+ return tensorflow::Status::OK();
+ }
+
+ // Matches a Parameter instruction and compares it with 'computation_'
+ // parameter instruction at index 0.
+ // Returns OK on success, error status otherwise.
+ tensorflow::Status MatchParameter() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction != computation_->parameter_instruction(0)) {
+ return InvalidArgument("Expected Parameter instruction.");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ // Matches a Tuple instruction.
+ // Pushes operand at 'tuple_index_' on the stack for subsequent processing.
+ // Returns OK on success, error status otherwise.
+ tensorflow::Status MatchTuple() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kTuple) {
+ return InvalidArgument("Expected Tuple instruction.");
+ }
+ stack_.push_back(instruction->operand(tuple_index_));
+ return tensorflow::Status::OK();
+ }
+
+ protected:
+ const HloComputation* computation_;
+ State state_;
+ int64 tuple_index_ = -1;
+ int64 const_value_ = -1;
+ std::vector<const HloInstruction*> stack_;
+
+ private:
+ bool IsSupportedIntType(const Shape& shape) {
+ return shape.element_type() == S32 || shape.element_type() == S64;
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MatcherBase);
+};
+
+// WhileConditionComputationMatcher matches one of the following two
+// target While condition computations:
+//
+// Case 1: LessThan
+//
+// PARAM
+// |
+// |
+// GTE0 CONST
+// \ /
+// \ /
+// PRED
+//
+//
+// Case 2: GreaterThan
+//
+// PARAM
+// |
+// |
+// CONST GTE0
+// \ /
+// \ /
+// PRED
+//
+// If we do not successfully match one of the two target cases, we return a
+// descriptive error status.
+//
+// If we do successfully match one of the cases, we parse and store the
+// following two pieces of information from the computation:
+//
+// *) 'tuple_index':
+// *) The loop induction variable tuple_index from the GetTupleElement
+// instruction of the matched computation.
+// *) Used in subsequent matching passes of while init operand and body
+// computations to select loop induction variable tuple element.
+//
+// *) 'loop_limit':
+// *) The integral value from Constant root operand in matched computation.
+// *) Used as the constant for the loop limit.
+//
+class WhileConditionComputationMatcher : public MatcherBase {
+ public:
+  explicit WhileConditionComputationMatcher(const HloComputation* computation)
+ : MatcherBase(computation, PRED) {
+ stack_.push_back(computation_->root_instruction());
+ }
+
+ // Loop attempting to match target computation.
+ tensorflow::Status Run() {
+ while (!stack_.empty()) {
+ switch (state_) {
+ case PRED: {
+ TF_RETURN_IF_ERROR(MatchPred());
+ break;
+ }
+ case CONST: {
+ TF_RETURN_IF_ERROR(MatchConst());
+ state_ = GTE0;
+ break;
+ }
+ case GTE0: {
+ TF_RETURN_IF_ERROR(MatchGetTupleElement());
+ state_ = PARAM;
+ break;
+ }
+ case PARAM: {
+ TF_RETURN_IF_ERROR(MatchParameter());
+ break;
+ }
+ default:
+ return InvalidArgument("Unexpected state.");
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ int64 loop_limit() const { return const_value_; }
+ int64 tuple_index() const { return tuple_index_; }
+
+ private:
+ tensorflow::Status MatchPred() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ // Push operands in canonical order: GetTupleElement, Constant.
+ if (instruction->opcode() == HloOpcode::kLt) {
+ stack_.push_back(instruction->operand(0));
+ stack_.push_back(instruction->operand(1));
+ } else if (instruction->opcode() == HloOpcode::kGt) {
+ stack_.push_back(instruction->operand(1));
+ stack_.push_back(instruction->operand(0));
+ } else {
+ return InvalidArgument("Condition must be LT or GT.");
+ }
+ state_ = CONST;
+ return tensorflow::Status::OK();
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WhileConditionComputationMatcher);
+};
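+
+// For example, a condition computation built as follows is matched with
+// tuple_index() == 0 and loop_limit() == 10 (illustrative; `loop_state_shape`
+// is assumed to be a tuple shape whose element 0 is S32, and the limit 10 is
+// arbitrary):
+//
+//   auto builder = HloComputation::Builder("condition");
+//   auto param = builder.AddInstruction(
+//       HloInstruction::CreateParameter(0, loop_state_shape, "loop_state"));
+//   auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+//       ShapeUtil::MakeShape(S32, {}), param, /*index=*/0));
+//   auto limit = builder.AddInstruction(
+//       HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(10)));
+//   builder.AddInstruction(HloInstruction::CreateBinary(
+//       ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, gte, limit));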
+
+// WhileInitOperandMatcher matches one of the following two target while
+// init operand sub-computations:
+//
+// Case 1: No copy.
+//
+// CONST // Tuple.operand(tuple_index)
+// |
+// TUPLE0 // While.operand(0)
+// |
+// WHILE
+//
+// Case 2: With copy.
+//
+// CONST // Tuple1.operand(tuple_index)
+// |
+// TUPLE1 // GetTupleElement.operand(0)
+// |
+// GTE0 // Copy.operand(0)
+// |
+// COPY // Tuple0.operand(tuple_index)
+// |
+// TUPLE0 // While.operand(0)
+// |
+// While
+//
+class WhileInitOperandMatcher : public MatcherBase {
+ public:
+ WhileInitOperandMatcher(const HloInstruction* while_hlo,
+ const int64 tuple_index)
+ : MatcherBase(while_hlo->parent(), WHILE, tuple_index) {
+ stack_.push_back(while_hlo);
+ }
+
+ // Loop attempting to match target computation.
+ tensorflow::Status Run() {
+ while (!stack_.empty()) {
+ switch (state_) {
+ case WHILE: {
+ TF_RETURN_IF_ERROR(MatchWhile());
+ break;
+ }
+ case TUPLE0: {
+ TF_RETURN_IF_ERROR(MatchTuple());
+ TF_RETURN_IF_ERROR(PostMatchTuple());
+ break;
+ }
+ case TUPLE1: {
+ TF_RETURN_IF_ERROR(MatchTuple());
+ state_ = CONST;
+ break;
+ }
+ case CONST: {
+ TF_RETURN_IF_ERROR(MatchConst());
+ break;
+ }
+ case COPY: {
+ TF_RETURN_IF_ERROR(MatchCopy());
+ state_ = GTE0;
+ break;
+ }
+ case GTE0: {
+ TF_RETURN_IF_ERROR(MatchGetTupleElement());
+ state_ = TUPLE1;
+ break;
+ }
+ default:
+ return InvalidArgument("Unexpected state.");
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ int64 loop_start() const { return const_value_; }
+
+ private:
+ tensorflow::Status MatchWhile() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kWhile) {
+ return InvalidArgument("While init match expected while instruction.");
+ }
+ // Push while 'init' operand.
+ stack_.push_back(instruction->operand(0));
+ state_ = TUPLE0;
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Status PostMatchTuple() {
+ // Transition to the next state based on matched tuple operand.
+ const HloInstruction* operand = stack_.back();
+ if (operand->opcode() == HloOpcode::kConstant) {
+ state_ = CONST;
+ } else if (operand->opcode() == HloOpcode::kCopy) {
+ state_ = COPY;
+ } else {
+ return InvalidArgument("Expected constant or copy tuple operand.");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WhileInitOperandMatcher);
+};
+
+// WhileBodyComputationMatcher matches one of the following two target
+// sub-computations:
+//
+// Case 1:
+//
+// PARAM
+// |
+// CONST GTE1
+// \ /
+// ADD // Tuple.operand(tuple_index).
+// |
+// TUPLE0 (root)
+//
+// Case 2:
+//
+// PARAM
+// |
+// CONST GTE1
+// \ /
+// ADD // Tuple.operand(tuple_index).
+// |
+// TUPLE1
+// |
+// GTE0
+// |
+// COPY
+// |
+// TUPLE0 (root)
+//
+// Note that the induction variable tuple element can have multiple users
+// in the while loop body computation, but only one update path.
+// Matching proceeds from the root[tuple_index] to param[tuple_index].
+//
+class WhileBodyComputationMatcher : public MatcherBase {
+ public:
+ WhileBodyComputationMatcher(const HloComputation* computation,
+ const int64 tuple_index)
+ : MatcherBase(computation, TUPLE0, tuple_index) {
+ stack_.push_back(computation_->root_instruction());
+ }
+
+ // Loop attempting to match target computation.
+ tensorflow::Status Run() {
+ while (!stack_.empty()) {
+ switch (state_) {
+ case TUPLE0: {
+ TF_RETURN_IF_ERROR(MatchTuple());
+ TF_RETURN_IF_ERROR(PostMatchTuple());
+ break;
+ }
+ case TUPLE1: {
+ TF_RETURN_IF_ERROR(MatchTuple());
+ state_ = ADD;
+ break;
+ }
+ case ADD: {
+ TF_RETURN_IF_ERROR(MatchAdd());
+ break;
+ }
+ case CONST: {
+ TF_RETURN_IF_ERROR(MatchConst());
+ state_ = GTE1;
+ break;
+ }
+ case COPY: {
+ TF_RETURN_IF_ERROR(MatchCopy());
+ state_ = GTE0;
+ break;
+ }
+ case GTE0: {
+ TF_RETURN_IF_ERROR(MatchGetTupleElement());
+ state_ = TUPLE1;
+ break;
+ }
+ case GTE1: {
+ TF_RETURN_IF_ERROR(MatchGetTupleElement());
+ state_ = PARAM;
+ break;
+ }
+ case PARAM: {
+ TF_RETURN_IF_ERROR(MatchParameter());
+ break;
+ }
+ default:
+ return InvalidArgument("Unexpected state.");
+ }
+ }
+ return tensorflow::Status::OK();
+ }
+
+ int64 loop_increment() const { return const_value_; }
+
+ private:
+ tensorflow::Status MatchAdd() {
+ const HloInstruction* instruction = stack_.back();
+ stack_.pop_back();
+ if (instruction->opcode() != HloOpcode::kAdd) {
+ return InvalidArgument("Expected Add induction variable update.");
+ }
+ // Push in canonical order: GetTupleElement, Constant.
+ if (instruction->operand(0)->opcode() == HloOpcode::kConstant &&
+ instruction->operand(1)->opcode() == HloOpcode::kGetTupleElement) {
+ stack_.push_back(instruction->operand(1));
+ stack_.push_back(instruction->operand(0));
+ } else if (instruction->operand(1)->opcode() == HloOpcode::kConstant &&
+ instruction->operand(0)->opcode() ==
+ HloOpcode::kGetTupleElement) {
+ stack_.push_back(instruction->operand(0));
+ stack_.push_back(instruction->operand(1));
+ } else {
+ return InvalidArgument("Invalid types for Add operands");
+ }
+ state_ = CONST;
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Status PostMatchTuple() {
+ // Transition to the next state based on matched tuple operand.
+ const HloInstruction* operand = stack_.back();
+ if (operand->opcode() == HloOpcode::kAdd) {
+ state_ = ADD;
+ } else if (operand->opcode() == HloOpcode::kCopy) {
+ state_ = COPY;
+ } else {
+ return InvalidArgument("Expected add or copy tuple operand.");
+ }
+ return tensorflow::Status::OK();
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(WhileBodyComputationMatcher);
+};
+
+} // namespace
+
+StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
+ const HloInstruction* while_hlo) {
+ if (while_hlo->opcode() != HloOpcode::kWhile) {
+ return InvalidArgument("Expected While instruction.");
+ }
+
+ WhileConditionComputationMatcher cond_matcher(while_hlo->while_condition());
+ TF_RETURN_IF_ERROR(cond_matcher.Run());
+
+ WhileInitOperandMatcher init_matcher(while_hlo, cond_matcher.tuple_index());
+ TF_RETURN_IF_ERROR(init_matcher.Run());
+
+ WhileBodyComputationMatcher body_matcher(while_hlo->while_body(),
+ cond_matcher.tuple_index());
+ TF_RETURN_IF_ERROR(body_matcher.Run());
+
+ // Check for valid For loop parameters.
+ if (init_matcher.loop_start() >= cond_matcher.loop_limit()) {
+ return InvalidArgument("Loop start must be less than loop limit.");
+ }
+ if (body_matcher.loop_increment() <= 0) {
+ return InvalidArgument("Loop increment must greater than zero.");
+ }
+ return std::make_tuple(init_matcher.loop_start(), cond_matcher.loop_limit(),
+ body_matcher.loop_increment());
+}
+
+} // namespace gpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h
new file mode 100644
index 0000000000..a4f527fce0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace gpu {
+
+// Runs an analysis of the while loop instruction 'while_hlo' (and its
+// associated sub-computations) to determine if it can be transformed into an
+// equivalent "for" loop with the following "for" loop parameters:
+//
+// *) 'loop_start': loop induction variable starting value.
+// *) 'loop_limit': loop induction variable limit value.
+// *) 'loop_increment': loop induction variable per-iteration increment value.
+//
+// Returns a std::tuple (loop_start, loop_limit, loop_increment) on success.
+// The values in the returned tuple are values extracted from the 'while_hlo'
+// operand (and its sub-computations) during analysis.
+// Returns an error status on failure.
+StatusOr<std::tuple<int64, int64, int64>> CanTransformWhileToFor(
+ const HloInstruction* while_hlo);
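+
+// Example usage (a minimal sketch; `while_hlo` is assumed to point at a kWhile
+// instruction in the module being compiled):
+//
+//   auto result = CanTransformWhileToFor(while_hlo);
+//   if (result.ok()) {
+//     int64 start, limit, increment;
+//     std::tie(start, limit, increment) = result.ConsumeValueOrDie();
+//   }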
+
+} // namespace gpu
+} // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
new file mode 100644
index 0000000000..799b30d21e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc
@@ -0,0 +1,218 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/while_transformer.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+class WhileTransformerTest : public HloTestBase {
+ protected:
+ WhileTransformerTest()
+ : module_(TestName()),
+ induction_variable_shape_(ShapeUtil::MakeShape(S32, {})),
+ data_shape_(ShapeUtil::MakeShape(F32, {8})),
+ loop_state_shape_(ShapeUtil::MakeTupleShape(
+ {induction_variable_shape_, data_shape_})),
+ condition_result_shape_(ShapeUtil::MakeShape(PRED, {})) {}
+
+ std::unique_ptr<HloComputation> BuildConditionComputation(
+ const int64 tuple_index, const int64 limit) {
+ auto builder = HloComputation::Builder(TestName() + ".Condition");
+ auto limit_const = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(limit)));
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ limit_const->shape(), loop_state, tuple_index));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(condition_result_shape_, HloOpcode::kLt,
+ induction_variable, limit_const));
+ return builder.Build();
+ }
+
+ std::unique_ptr<HloComputation> BuildBodyComputation(
+ const int64 ind_var_tuple_index, const int64 data_tuple_index,
+ const int64 increment, bool insert_copies = false) {
+ auto builder = HloComputation::Builder(TestName() + ".Body");
+ // Create param instruction to access loop state.
+ auto loop_state = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state"));
+ // Update the induction variable GTE(ind_var_tuple_index).
+ auto induction_variable =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, loop_state, ind_var_tuple_index));
+ auto inc = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(increment)));
+ auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+ induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc));
+ // Update data GTE(data_tuple_index).
+ auto data = builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ data_shape_, loop_state, data_tuple_index));
+    // Add a second use of 'induction_variable' whose result reaches only the
+    // data element of the output tuple, not the induction variable element.
+ auto update = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, induction_variable, {8}));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ data_shape_, HloOpcode::kAdd, data, update));
+ // Create output Tuple.
+ auto tuple0 =
+ ind_var_tuple_index == 0
+ ? builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}))
+ : builder.AddInstruction(HloInstruction::CreateTuple({add1, add0}));
+ if (insert_copies) {
+ InsertTupleElementCopies(ind_var_tuple_index, tuple0, &builder);
+ }
+ return builder.Build();
+ }
+
+ HloInstruction* BuildWhileInstruction(HloComputation* condition,
+ HloComputation* body,
+ const int64 ind_var_tuple_index,
+ const int64 ind_var_init,
+ bool insert_copies = false) {
+ auto builder = HloComputation::Builder(TestName() + ".While");
+ auto induction_var_init =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR0<int32>(ind_var_init)));
+ auto data_init = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(
+ {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f})));
+ auto loop_state_init =
+ ind_var_tuple_index == 0
+ ? builder.AddInstruction(
+ HloInstruction::CreateTuple({induction_var_init, data_init}))
+ : builder.AddInstruction(
+ HloInstruction::CreateTuple({data_init, induction_var_init}));
+ if (insert_copies) {
+ loop_state_init = InsertTupleElementCopies(ind_var_tuple_index,
+ loop_state_init, &builder);
+ }
+ auto while_hlo = builder.AddInstruction(HloInstruction::CreateWhile(
+ loop_state_shape_, condition, body, loop_state_init));
+ module_.AddEntryComputation(builder.Build());
+ return while_hlo;
+ }
+
+ HloInstruction* InsertTupleElementCopies(const int64 ind_var_tuple_index,
+ HloInstruction* tuple0,
+ HloComputation::Builder* builder) {
+ auto gte0 = builder->AddInstruction(HloInstruction::CreateGetTupleElement(
+ induction_variable_shape_, tuple0, ind_var_tuple_index));
+ const int64 gte1_tuple_index = ind_var_tuple_index == 0 ? 1 : 0;
+ auto gte1 = builder->AddInstruction(HloInstruction::CreateGetTupleElement(
+ data_shape_, tuple0, gte1_tuple_index));
+ // Insert copies.
+ auto copy0 = builder->AddInstruction(
+ HloInstruction::CreateUnary(gte0->shape(), HloOpcode::kCopy, gte0));
+ auto copy1 = builder->AddInstruction(
+ HloInstruction::CreateUnary(gte1->shape(), HloOpcode::kCopy, gte1));
+
+ return ind_var_tuple_index == 0
+ ? builder->AddInstruction(
+ HloInstruction::CreateTuple({copy0, copy1}))
+ : builder->AddInstruction(
+ HloInstruction::CreateTuple({copy1, copy0}));
+ }
+
+ HloModule module_;
+ Shape induction_variable_shape_;
+ Shape data_shape_;
+ Shape loop_state_shape_;
+ Shape condition_result_shape_;
+};
+
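+// A note on the checks below: on success, gpu::CanTransformWhileToFor yields a
+// tuple whose elements are checked as (induction variable start, loop limit,
+// loop increment). For example, a condition comparing GTE(loop_state, 0)
+// against 10, an initial induction value of 0, and a body increment of 1 give
+// the expected values (0, 10, 1).
+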
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
+ auto result =
+ gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0));
+ EXPECT_TRUE(result.ok());
+ auto tuple = result.ConsumeValueOrDie();
+ EXPECT_EQ(0, std::get<0>(tuple));
+ EXPECT_EQ(10, std::get<1>(tuple));
+ EXPECT_EQ(1, std::get<2>(tuple));
+}
+
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0_WithBodyCopies) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body =
+ module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1, true));
+ auto result =
+ gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0));
+ EXPECT_TRUE(result.ok());
+ auto tuple = result.ConsumeValueOrDie();
+ EXPECT_EQ(0, std::get<0>(tuple));
+ EXPECT_EQ(10, std::get<1>(tuple));
+ EXPECT_EQ(1, std::get<2>(tuple));
+}
+
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement0_WithInitCopies) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
+ auto result = gpu::CanTransformWhileToFor(
+ BuildWhileInstruction(condition, body, 0, 0, true));
+ EXPECT_TRUE(result.ok());
+ auto tuple = result.ConsumeValueOrDie();
+ EXPECT_EQ(0, std::get<0>(tuple));
+ EXPECT_EQ(10, std::get<1>(tuple));
+ EXPECT_EQ(1, std::get<2>(tuple));
+}
+
+TEST_F(WhileTransformerTest, InductionVariableAtTupleElement1) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(1, 10));
+ auto body = module_.AddEmbeddedComputation(BuildBodyComputation(1, 0, 1));
+ auto result =
+ gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 1, 0));
+ EXPECT_TRUE(result.ok());
+ auto tuple = result.ConsumeValueOrDie();
+ EXPECT_EQ(0, std::get<0>(tuple));
+ EXPECT_EQ(10, std::get<1>(tuple));
+ EXPECT_EQ(1, std::get<2>(tuple));
+}
+
+TEST_F(WhileTransformerTest, InvalidLoopLimit) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(0, 5));
+ auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, 1));
+ auto result = gpu::CanTransformWhileToFor(
+ BuildWhileInstruction(condition, body, 0, 10));
+ EXPECT_FALSE(result.ok());
+ EXPECT_MATCH(
+ result.status().error_message(),
+ testing::ContainsRegex("Loop start must be less than loop limit."));
+}
+
+TEST_F(WhileTransformerTest, InvalidLoopIncrement) {
+ auto condition =
+ module_.AddEmbeddedComputation(BuildConditionComputation(0, 10));
+ auto body = module_.AddEmbeddedComputation(BuildBodyComputation(0, 1, -1));
+ auto result =
+ gpu::CanTransformWhileToFor(BuildWhileInstruction(condition, body, 0, 0));
+ EXPECT_FALSE(result.ok());
+ EXPECT_MATCH(
+ result.status().error_message(),
+ testing::ContainsRegex("Loop increment must greater than zero."));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
new file mode 100644
index 0000000000..cd00a41a03
--- /dev/null
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -0,0 +1,165 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Example HLO graph which demonstrates the Graphviz dumper for HLO
+// computations. When run, it pushes the example DOT graph to the Graphviz
+// service and prints the URL. Useful for seeing the effect of changes to the
+// graph generation code.
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Adds a computation to the given HLO module which adds a scalar constant to
+// its parameter and returns the result.
+HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) {
+ auto builder =
+ HloComputation::Builder(tensorflow::strings::StrCat("add_", addend));
+ auto x_value = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "x_value"));
+ auto half = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.5)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ half->shape(), HloOpcode::kAdd, x_value, half));
+ return module->AddEmbeddedComputation(builder.Build());
+}
+
+// Adds a computation to the given HLO module which sums its two parameters and
+// returns the result.
+HloComputation* ScalarSumComputation(HloModule* module) {
+ auto builder = HloComputation::Builder("add");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "rhs"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
+ return module->AddEmbeddedComputation(builder.Build());
+}
+
+// Adds a computation to the given HLO module which forwards its argument to a
+// kCall instruction which then calls the given computation.
+HloComputation* CallForwardingComputation(HloComputation* computation,
+ HloModule* module) {
+ auto builder = HloComputation::Builder("call_forward");
+ auto arg = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "arg"));
+ builder.AddInstruction(
+ HloInstruction::CreateCall(arg->shape(), {arg}, computation));
+ return module->AddEmbeddedComputation(builder.Build());
+}
+
+// Create a large, arbitrary computation with many different kinds of
+// instructions. Sets the computation as the entry to an HLO module and returns
+// the module.
+std::unique_ptr<HloModule> MakeBigGraph() {
+ auto module = MakeUnique<HloModule>("BigGraph");
+
+ auto builder = HloComputation::Builder("TestBigGraphvizGraph");
+
+ // Shapes used in the computation.
+ auto mshape = ShapeUtil::MakeShape(F32, {3, 5});
+ auto vshape = ShapeUtil::MakeShape(F32, {3});
+ auto sshape = ShapeUtil::MakeShape(F32, {3});
+
+ // Create a set of parameter instructions.
+ auto param_v0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, vshape, "foo"));
+ auto param_v1 =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, vshape, "bar"));
+ auto param_v2 =
+ builder.AddInstruction(HloInstruction::CreateParameter(2, vshape, "baz"));
+ auto param_s =
+ builder.AddInstruction(HloInstruction::CreateParameter(3, sshape, "qux"));
+ auto param_m =
+ builder.AddInstruction(HloInstruction::CreateParameter(4, mshape, "zzz"));
+
+ // Add an arbitrary expression of different instructions.
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0));
+ auto clamp = builder.AddInstruction(HloInstruction::CreateTernary(
+ vshape, HloOpcode::kClamp, copy, param_v1, param_v2));
+ auto dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({dot, param_s, clamp}));
+ auto scalar = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(sshape, tuple, 2));
+ auto add_one = AddScalarConstantComputation(1.0, module.get());
+ auto rng = builder.AddInstruction(
+ HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m}));
+ auto one = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto add_computation = ScalarSumComputation(module.get());
+ builder.AddInstruction(
+ HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation));
+ auto map1 = builder.AddInstruction(
+ HloInstruction::CreateMap(sshape, {scalar}, add_one));
+ auto map2 = builder.AddInstruction(
+ HloInstruction::CreateMap(sshape, {map1}, add_one));
+ auto map3 = builder.AddInstruction(
+ HloInstruction::CreateMap(sshape, {map2}, add_one));
+
+ // Create a fusion instruction containing the chain of map instructions.
+ auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+ sshape, HloInstruction::FusionKind::kLoop, map3));
+ fusion->FuseInstruction(map2);
+ fusion->FuseInstruction(map1);
+
+ // Add a random trace instruction.
+ builder.AddInstruction(HloInstruction::CreateTrace("trace", dot));
+
+  // Add a call instruction which calls the call-forwarding computation, which
+  // in turn calls another computation.
+ auto call_computation = CallForwardingComputation(add_one, module.get());
+ builder.AddInstruction(
+ HloInstruction::CreateCall(fusion->shape(), {fusion}, call_computation));
+
+ module->AddEntryComputation(builder.Build());
+ return module;
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ auto module = xla::MakeBigGraph();
+
+ printf("Graph URL: %s\n",
+ xla::hlo_graph_dumper::DumpGraph(
+ *module->entry_computation(), "Example computation",
+ /*show_addresses=*/false, /*show_layouts=*/false)
+ .c_str());
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
new file mode 100644
index 0000000000..4b7d795cc6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -0,0 +1,520 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <functional>
+#include <list>
+#include <queue>
+#include <set>
+#include <sstream>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+std::unique_ptr<HloComputation> HloComputation::Builder::Build(
+ HloInstruction* root_instruction) {
+ int parameter_count = 0;
+ for (auto& instruction : instructions_) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ parameter_count++;
+ }
+ }
+  // If root_instruction is not specified, use the last added instruction.
+ HloInstruction* root =
+ root_instruction ? root_instruction : last_added_instruction_;
+ CHECK_NE(nullptr, root);
+
+ return WrapUnique(
+ new HloComputation(name_, parameter_count, &instructions_, root));
+}
+
+HloComputation::HloComputation(
+ const string& name, int parameter_count,
+ std::vector<std::unique_ptr<HloInstruction>>* instructions,
+ HloInstruction* root_instruction)
+ : name_(name),
+ root_instruction_(root_instruction),
+ instruction_name_uniquer_(/*separator=*/".") {
+ param_instructions_.resize(parameter_count, nullptr);
+ bool root_found = false;
+ for (auto& instruction : *instructions) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ int64 param_no = instruction->parameter_number();
+ CHECK_GE(param_no, 0);
+ CHECK_LT(param_no, param_instructions_.size());
+ CHECK_EQ(nullptr, param_instructions_[param_no]);
+ param_instructions_[param_no] = instruction.get();
+ }
+ root_found |= instruction.get() == root_instruction_;
+ AddInstructionInternal(std::move(instruction));
+ }
+ CHECK(root_found);
+}
+
+HloInstruction* HloComputation::AddInstruction(
+ std::unique_ptr<HloInstruction> instruction) {
+ CHECK(instruction->opcode() != HloOpcode::kParameter)
+ << "Parameter instructions cannot be added to a computation after "
+ << "it has been built";
+ return AddInstructionInternal(std::move(instruction));
+}
+
+HloInstruction* HloComputation::AddInstructionInternal(
+ std::unique_ptr<HloInstruction> instruction) {
+ // Generate a unique name for the instruction.
+ instruction->set_name(
+ instruction_name_uniquer_.GetUniqueName(instruction->name()));
+ instruction->set_parent(this);
+ HloInstruction* pinst = instruction.get();
+ instruction_iterators_[pinst] =
+ instructions_.insert(instructions_.end(), std::move(instruction));
+ return pinst;
+}
+
+void HloComputation::RemoveInstructionAndUnusedOperands(
+ HloInstruction* instruction) {
+ CHECK_NE(root_instruction(), instruction);
+
+ CHECK_EQ(0, instruction->user_count());
+ CHECK_NE(instruction->opcode(), HloOpcode::kParameter)
+ << "Cannot remove parameter instructions";
+ std::queue<HloInstruction*> remove;
+ remove.push(instruction);
+ while (!remove.empty()) {
+ HloInstruction* item = remove.front();
+ remove.pop();
+ if (item->user_count() != 0 || item == root_instruction_ ||
+ item->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ for (int i = 0; i < item->operand_count(); ++i) {
+ remove.push(item->mutable_operand(i));
+ }
+
+ // If an instruction has the same operand more than once, we must not remove
+ // it again.
+ RemoveInstruction(item);
+ }
+}
+
+bool HloComputation::RemoveInstructionIfFound(HloInstruction* instruction) {
+ CHECK_NE(instruction->opcode(), HloOpcode::kParameter)
+ << "Cannot remove parameter instructions";
+ CHECK_NE(root_instruction(), instruction) << "cannot remove root instruction";
+ CHECK_EQ(0, instruction->user_count())
+ << "instruction with users cannot be removed";
+
+ if (instruction_iterators_.count(instruction) == 0) {
+ return false;
+ }
+ auto inst_it = instruction_iterators_.at(instruction);
+ (*inst_it)->set_parent(nullptr);
+ instruction->DetachFromOperands();
+ instructions_.erase(inst_it);
+ return true;
+}
+
+void HloComputation::RemoveInstruction(HloInstruction* instruction) {
+ CHECK(RemoveInstructionIfFound(instruction))
+ << instruction->ToString() << " is not a member of computation "
+ << name();
+}
+
+void HloComputation::ReplaceUsesOfInstruction(
+ HloInstruction* instruction_to_replace, HloInstruction* instruction) {
+ instruction_to_replace->ReplaceAllUsesWith(instruction);
+ if (instruction_to_replace == root_instruction()) {
+ set_root_instruction(instruction);
+ }
+}
+
+void HloComputation::set_root_instruction(
+ HloInstruction* new_root_instruction) {
+ // The shape of the root (ignoring layout) is an invariant of the computation.
+ CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
+ root_instruction_->shape()))
+ << new_root_instruction->shape().ShortDebugString()
+ << " is incompatible with "
+ << root_instruction_->shape().ShortDebugString();
+ bool root_found = false;
+ for (auto& instruction : instructions_) {
+ if (new_root_instruction == instruction.get()) {
+ root_found = true;
+ break;
+ }
+ }
+ DCHECK(root_found);
+
+ root_instruction_ = new_root_instruction;
+}
+
+namespace {
+
+// Helper class which computes the post order of an expression rooted at a
+// particular instruction.
+class InstructionPostOrderer : public DfsHloVisitorWithDefault {
+ public:
+ // added_instructions is the set of instructions which have already been
+ // accounted for in the post order in previous invocations of
+ // GetOrder. Without this mechanism, instructions which are predecessors of
+ // multiple root instructions of the computation can be added to the post
+ // order more than once.
+ static std::list<HloInstruction*> GetOrder(
+ HloInstruction* root,
+ tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions) {
+ InstructionPostOrderer orderer(added_instructions);
+ TF_CHECK_OK(root->Accept(&orderer));
+ return std::move(orderer.post_order_);
+ }
+
+ private:
+ explicit InstructionPostOrderer(
+ tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions)
+ : added_instructions_(added_instructions) {}
+ ~InstructionPostOrderer() override {}
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ if (added_instructions_->count(hlo_instruction) == 0) {
+ post_order_.push_back(hlo_instruction);
+ added_instructions_->insert(hlo_instruction);
+ }
+ return Status::OK();
+ }
+
+ std::list<HloInstruction*> post_order_;
+ tensorflow::gtl::FlatSet<HloInstruction*>* added_instructions_;
+};
+
+// Helper which builds a post order of the HLO call graph.
+void ComputeComputationPostOrder(
+ HloComputation* computation,
+ tensorflow::gtl::FlatSet<HloComputation*>* visited,
+ std::list<HloComputation*>* post_order) {
+ if (visited->count(computation) > 0) {
+ return;
+ }
+
+ for (auto& instruction : computation->instructions()) {
+ for (auto& called_computation : instruction->MakeCalledComputationsSet()) {
+ ComputeComputationPostOrder(called_computation, visited, post_order);
+ }
+ }
+
+ visited->insert(computation);
+ post_order->push_back(computation);
+ return;
+}
+
+} // namespace
+
+std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ std::list<HloInstruction*> post_order;
+ std::list<HloInstruction*> trace_instructions;
+ tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
+ for (auto& instruction : instructions_) {
+ if (instruction->opcode() == HloOpcode::kTrace) {
+ // Trace instructions aren't handled by the DFS visitor. Add trace
+ // instructions to the post order at the end (necessarily they have no
+ // users).
+ trace_instructions.push_back(instruction.get());
+ } else if (instruction->users().empty()) {
+ post_order.splice(post_order.end(),
+ InstructionPostOrderer::GetOrder(instruction.get(),
+ &added_instructions));
+ }
+ }
+ post_order.splice(post_order.end(), trace_instructions);
+ CHECK_EQ(instructions_.size(), post_order.size())
+ << "number of instructions does not match post order size";
+ return post_order;
+}
+
+std::list<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
+ const {
+ tensorflow::gtl::FlatSet<HloComputation*> visited;
+ std::list<HloComputation*> post_order;
+
+ // To avoid special handling of this computation, cast away const of
+ // 'this'. 'this' is immediately removed from the post order after
+ // construction.
+ ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
+ &post_order);
+
+ // We don't want to include this computation in the post order.
+ CHECK_EQ(this, post_order.back());
+ post_order.pop_back();
+
+ return post_order;
+}
+
+string HloComputation::ToString() const {
+ std::ostringstream s;
+ s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
+ << " { \n";
+ for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
+ s << " " << instruction->ToString() << "\n";
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ for (const auto& fused_instruction : instruction->fused_instructions()) {
+ s << " " << fused_instruction->ToString() << "\n";
+ }
+ }
+ }
+ s << "}";
+ return s.str();
+}
+
+void HloComputation::FuseInstructionsInto(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction* fusion_instruction) {
+ CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
+ HloInstruction* root = instructions_to_fuse.front();
+ root->ReplaceAllUsesWith(fusion_instruction);
+ if (root == root_instruction()) {
+ set_root_instruction(fusion_instruction);
+ }
+ RemoveInstruction(root);
+ for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
+ HloInstruction* instruction = instructions_to_fuse[i];
+ fusion_instruction->FuseInstruction(instruction);
+ if (instruction->user_count() == 0) {
+ RemoveInstruction(instruction);
+ }
+ }
+}
+
+HloInstruction* HloComputation::CreateFusionInstruction(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction::FusionKind fusion_kind) {
+ HloInstruction* root = instructions_to_fuse.front();
+ HloInstruction* fusion_instruction = AddInstruction(
+ HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
+ FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
+ return fusion_instruction;
+}
+
+HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction::FusionKind fusion_kind, const Window& window,
+ const ConvolutionDimensionNumbers& conv_dnums) {
+ CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind ||
+ HloInstruction::FusionKind::kConvBackwardInput == fusion_kind);
+ HloInstruction* root = instructions_to_fuse.front();
+ HloInstruction* fusion_instruction =
+ AddInstruction(HloInstruction::CreateFusionForBackwardConvolution(
+ root->shape(), fusion_kind, window, conv_dnums, root));
+ FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
+ return fusion_instruction;
+}
+
+StatusOr<HloInstruction*> HloComputation::DeepCopyTuple(
+ HloInstruction* instruction) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(instruction->shape()));
+ std::vector<HloInstruction*> element_copies;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+ ++i) {
+ HloInstruction* gte = AddInstruction(HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, i));
+ // Recurse to copy tuple elements. For array elements, insert a kCopy
+ // because GetTupleElement forwards a pointer to the tuple element buffer.
+ HloInstruction* element_copy;
+ if (ShapeUtil::IsTuple(gte->shape())) {
+ TF_ASSIGN_OR_RETURN(element_copy, DeepCopyTuple(gte));
+ } else {
+ element_copy = AddInstruction(
+ HloInstruction::CreateUnary(gte->shape(), HloOpcode::kCopy, gte));
+ }
+ element_copies.push_back(element_copy);
+ }
+
+ // Gather element copies into a tuple with a new Tuple instruction.
+ return AddInstruction(HloInstruction::CreateTuple(element_copies));
+}
+
+StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
+ HloInstruction* instruction) {
+ if (instruction->parent() != this) {
+ return FailedPrecondition(
+ "Can't deep copy instruction %s: instruction is not in computation %s",
+ instruction->name().c_str(), name().c_str());
+ }
+
+ // For tuple instructions, perform a deep copy. For array instructions, copy
+ // with a kCopy instruction.
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ return DeepCopyTuple(instruction);
+ } else if (ShapeUtil::IsArray(instruction->shape())) {
+ return AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ } else {
+ return FailedPrecondition(
+ "Can only copy array and tuple shaped instructions");
+ }
+}
+
+Status HloComputation::AddControlDependency(HloInstruction* predecessor,
+ HloInstruction* successor) {
+ TF_RET_CHECK(instruction_iterators_.count(predecessor) > 0);
+ TF_RET_CHECK(instruction_iterators_.count(successor) > 0);
+ successor->AddControlPredecessor(predecessor);
+ return Status::OK();
+}
+
+ProgramShape HloComputation::ComputeProgramShape() const {
+ ProgramShape program_shape;
+
+ for (auto* param_instruction : param_instructions_) {
+ *program_shape.add_parameters() = param_instruction->shape();
+ *program_shape.add_parameter_names() = param_instruction->parameter_name();
+ }
+ *program_shape.mutable_result() = root_instruction_->shape();
+
+ LayoutUtil::ClearLayout(&program_shape);
+ return program_shape;
+}
+
+bool HloComputation::operator==(const HloComputation& other) const {
+ std::set<std::pair<const HloInstruction*, const HloInstruction*>> visited;
+ std::function<bool(const HloInstruction*, const HloInstruction*)> eq =
+ [&visited, &eq](const HloInstruction* a, const HloInstruction* b) {
+ // If <a,b> are visited but not identical, the recursion should have
+ // been aborted. So, if <a,b> are visited at this point, they must be
+ // identical.
+ if (visited.count(std::make_pair(a, b)) > 0) return true;
+ visited.emplace(a, b);
+ return a->Identical(
+ *b, eq, [](const HloComputation* a, const HloComputation* b) {
+ return *a == *b;
+ });
+ };
+ return eq(root_instruction(), other.root_instruction());
+}
+
+void HloComputation::ReplaceWithNewInstruction(
+ HloInstruction* old_instruction,
+ std::unique_ptr<HloInstruction> new_instruction) {
+ ReplaceInstruction(old_instruction,
+ AddInstruction(std::move(new_instruction)));
+}
+
+void HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
+ HloInstruction* new_instruction) {
+ CHECK(ShapeUtil::Compatible(old_instruction->shape(),
+ new_instruction->shape()));
+ VLOG(10) << "transformed " << old_instruction->ToString() << " to "
+ << new_instruction->ToString();
+ ReplaceUsesOfInstruction(old_instruction, new_instruction);
+ RemoveInstructionAndUnusedOperands(old_instruction);
+}
+
+HloComputation::ReachabilityMap::ReachabilityMap(
+ const std::list<HloInstruction*>& all_instructions) {
+ const int n = all_instructions.size();
+ int next_id = 0;
+ for (const auto* hlo : all_instructions) {
+ ids_[hlo] = next_id;
+ next_id++;
+ }
+ DCHECK_EQ(n, ids_.size()); // instructions should be unique
+ matrix_.Reset(n * n);
+}
+
+void HloComputation::ReachabilityMap::SetReachable(const HloInstruction* a,
+ const HloInstruction* b) {
+ const int id_a = FindOrDie(ids_, a);
+ const int id_b = FindOrDie(ids_, b);
+ matrix_.set(id_a * ids_.size() + id_b);
+}
+
+bool HloComputation::ReachabilityMap::IsReachable(
+ const HloInstruction* a, const HloInstruction* b) const {
+ const int id_a = FindOrDie(ids_, a);
+ const int id_b = FindOrDie(ids_, b);
+ return matrix_.get(id_a * ids_.size() + id_b);
+}
+
+bool HloComputation::ReachabilityMap::IsConnected(
+ const HloInstruction* a, const HloInstruction* b) const {
+ const int id_a = FindOrDie(ids_, a);
+ const int id_b = FindOrDie(ids_, b);
+ return matrix_.get(id_a * ids_.size() + id_b) ||
+ matrix_.get(id_b * ids_.size() + id_a);
+}
+
+void HloComputation::ReachabilityMap::SetReachableAndTransitiveClosure(
+ const HloInstruction* a, const HloInstruction* b) {
+ const int id_a = FindOrDie(ids_, a);
+ const int id_b = FindOrDie(ids_, b);
+ const int n = ids_.size();
+ matrix_.set(id_a * n + id_b);
+
+ // Copy transitive set for b into entries for a
+ for (int i = 0; i < n; i++) {
+ if (matrix_.get(id_b * n + i)) {
+ matrix_.set(id_a * n + i);
+ }
+ }
+}
+
+std::unique_ptr<HloComputation::ReachabilityMap>
+HloComputation::ComputeTransitiveOperands() const {
+ const auto all = MakeInstructionPostOrder();
+ auto result = MakeUnique<HloComputation::ReachabilityMap>(all);
+
+ // Fill in the dependency bit matrix
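+  // Note: 'all' is a post order, so each instruction's operands are processed
+  // before the instruction itself; copying an operand's already-complete row
+  // therefore yields the full transitive-operand set in a single pass.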
+ for (const auto* hlo : all) {
+ for (const HloInstruction* operand : hlo->operands()) {
+ result->SetReachableAndTransitiveClosure(hlo, operand);
+ }
+ }
+ return result;
+}
+
+Status HloComputation::Accept(DfsHloVisitor* visitor) const {
+ // Visit all dead roots.
+ for (auto& instruction : instructions()) {
+ if (instruction->user_count() == 0 &&
+ instruction.get() != root_instruction()) {
+ // Call FinishVisit only at the end.
+ TF_RETURN_IF_ERROR(
+ instruction->Accept(visitor, /*call_finish_visit=*/false));
+ }
+ }
+ // Visit root instruction last.
+ return root_instruction()->Accept(visitor, /*call_finish_visit=*/true);
+}
+
+Status HloComputation::Accept(
+ const FunctionVisitor::VisitorFunction& visitor_func) const {
+ FunctionVisitor visitor(visitor_func);
+ return this->Accept(&visitor);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
new file mode 100644
index 0000000000..67df5797dc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class HloModule;
+
+// Describes a computation at the HLO level.
+//
+// An HloComputation contains a directed acyclic graph of HLO instructions. The
+// computation has a single root instruction which produces the output of the
+// computation.
+class HloComputation {
+ public:
+ // Builder class for HloComputation.
+ class Builder {
+ public:
+ explicit Builder(const string& name)
+ : name_(name), last_added_instruction_(nullptr) {}
+
+ // Build and return an HloComputation. The parameter root_instruction
+ // specifies the already-added instruction to use as the root. If
+ // root_instruction is nullptr then use the last added instruction as the
+ // root.
+ std::unique_ptr<HloComputation> Build(
+ HloInstruction* root_instruction = nullptr);
+
+ HloInstruction* AddInstruction(
+ std::unique_ptr<HloInstruction> instruction) {
+ instructions_.push_back(std::move(instruction));
+ last_added_instruction_ = instructions_.back().get();
+ return last_added_instruction_;
+ }
+
+ private:
+ const string name_;
+ HloInstruction* last_added_instruction_;
+ std::vector<std::unique_ptr<HloInstruction>> instructions_;
+ };
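+
+  // Illustrative Builder usage (a minimal sketch following the tests in this
+  // change; the names are examples only):
+  //
+  //   auto builder = HloComputation::Builder("Negate");
+  //   auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+  //       0, ShapeUtil::MakeShape(F32, {}), "param0"));
+  //   builder.AddInstruction(HloInstruction::CreateUnary(
+  //       param->shape(), HloOpcode::kNegate, param));
+  //   std::unique_ptr<HloComputation> computation = builder.Build();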
+
+ // Add an instruction to the computation. The computation takes ownership of
+ // the instruction.
+ HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction);
+
+ // Remove an instruction from the computation. The instruction must have no
+ // users. Instruction is deallocated with this call.
+ void RemoveInstruction(HloInstruction* instruction);
+
+  // Remove an instruction from the computation and also, transitively, any
+  // operand that is left without users after the removal. The instruction
+  // must have no users. Instruction is deallocated with this call.
+ void RemoveInstructionAndUnusedOperands(HloInstruction* instruction);
+
+ // Replace all uses of "instruction_to_replace" with "instruction". Also, if
+ // instruction_to_replace is the root of this computation then the root is set
+ // to "instruction". Does not remove "instruction_to_replace".
+ void ReplaceUsesOfInstruction(HloInstruction* instruction_to_replace,
+ HloInstruction* instruction);
+
+ // Set the root of the computation to the given instruction. The instruction
+ // must have already been added to the computation and have the same shape as
+ // the result of the computation.
+ void set_root_instruction(HloInstruction* instruction);
+
+ // Return the root instruction of the computation. The root instruction is the
+ // instruction which produces the output of the computation.
+ HloInstruction* root_instruction() const { return root_instruction_; }
+
+ // Returns the number of parameters for this computation.
+ int64 num_parameters() const { return param_instructions_.size(); }
+
+ // Returns the parameter instruction for the given parameter number.
+ HloInstruction* parameter_instruction(int64 param_no) const {
+ CHECK_GE(param_no, 0);
+ CHECK_LT(param_no, param_instructions_.size());
+ return param_instructions_[param_no];
+ }
+
+ const std::vector<HloInstruction*>& parameter_instructions() const {
+ return param_instructions_;
+ }
+
+ const string& name() const { return name_; }
+
+ // Return a string representation of the computation.
+ string ToString() const;
+
+ const std::list<std::unique_ptr<HloInstruction>>& instructions() const {
+ return instructions_;
+ }
+
+ // Add a control dependency between the two instructions in this computation
+ // so that the 'predecessor' is visited before the 'successor' during the DFS
+ // traversal of the computation. Returns an error status if either of the
+ // given instructions does not belong to the current computation.
+ //
+ // This is used to enforce an additional ordering requirement that is not
+ // captured by normal data dependencies, such as ordering among Send or Recv
+ // operations to avoid deadlock.
+ Status AddControlDependency(HloInstruction* predecessor,
+ HloInstruction* successor);
+
+ // Compute and return a post-order of the instructions in the computation. In
+ // this order, definitions of values always appear before their uses.
+ std::list<HloInstruction*> MakeInstructionPostOrder() const;
+
+ // Computes and returns the mapping from HLO to its transitive operands.
+ class ReachabilityMap;
+ std::unique_ptr<ReachabilityMap> ComputeTransitiveOperands() const;
+
+ int64 instruction_count() const { return instructions_.size(); }
+
+ // Creates and returns a list of the embedded computations called by this
+ // computation. This includes all embedded computations called directly or
+ // transitively. The embedded computations are sorted such that if computation
+  // A calls computation B (e.g., via a map instruction) then A will appear
+  // after B in the list.
+ std::list<HloComputation*> MakeEmbeddedComputationsList() const;
+
+ // Creates a fusion instruction containing the given instructions.
+ // `fusion_kind` indicates the type of the fusion, e.g., loop fusion or fusion
+ // into a library call. Instructions must be in reverse topological order
+ // (root of the fused expression first). Replaces all uses of the original
+ // root instruction with the fusion instruction. The original instructions are
+ // removed if they have no uses after fusion (this is necessarily true for at
+ // least the root).
+ HloInstruction* CreateFusionInstruction(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction::FusionKind fusion_kind);
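+
+  // For example (a sketch; 'map1', 'map2' and 'map3' are assumed to form a
+  // chain already added to this computation, with 'map3' as the fused root):
+  //
+  //   computation->CreateFusionInstruction({map3, map2, map1},
+  //                                        HloInstruction::FusionKind::kLoop);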
+
+ // Creates a fusion instruction that represents a backward convolution. This
+ // is similar to CreateFusionInstruction but takes window and conv_dnums which
+ // indicate the window and convolution dimension numbers of the backward
+ // convolution.
+ HloInstruction* CreateFusionInstructionForBackwardConvolution(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction::FusionKind fusion_kind, const Window& window,
+ const ConvolutionDimensionNumbers& conv_dnums);
+
+ // Create a deep copy of the given instruction and return the instruction
+ // producing the copied result. All instructions performing the copy are added
+ // to the computation. For array-shaped values, this method trivially returns
+ // a kCopy instruction. For tuple-shaped instructions, the copy is performed
+ // with a series of kGetTupleElement and kTuple instructions.
+ StatusOr<HloInstruction*> DeepCopyInstruction(HloInstruction* instruction);
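+
+  // For a two-element tuple-shaped instruction 't' with array elements, the
+  // returned copy has the form (sketch):
+  //
+  //   Tuple(Copy(GetTupleElement(t, 0)), Copy(GetTupleElement(t, 1)))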
+
+ // Computes and returns the ProgramShape of this computation (shape of
+ // parameters and result without layout).
+ ProgramShape ComputeProgramShape() const;
+
+ // Return whether `*this` and `other` are functionally equivalent.
+ bool operator==(const HloComputation& other) const;
+
+ // Replaces old instruction with newly created instruction. Removes old
+ // instruction from computation. Updates uses and root instruction.
+ void ReplaceWithNewInstruction(
+ HloInstruction* old_instruction,
+ std::unique_ptr<HloInstruction> new_instruction);
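+
+  // For example (a sketch; 'old_add', 'lhs' and 'rhs' are placeholder names):
+  //
+  //   computation->ReplaceWithNewInstruction(
+  //       old_add, HloInstruction::CreateBinary(
+  //                    old_add->shape(), HloOpcode::kMultiply, lhs, rhs));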
+
+ // Replace old instruction with new instruction. Updates uses and root
+ // instruction. Removes old instruction from computation. Precondition:
+  // old_instruction and new_instruction must have compatible shapes.
+ void ReplaceInstruction(HloInstruction* old_instruction,
+ HloInstruction* new_instruction);
+
+ // Set/get the module containing this computation.
+ void set_parent(HloModule* module) { parent_ = module; }
+ const HloModule* parent() const { return parent_; }
+
+ // Visit every node in the computation in DFS post-order with the given
+ // visitor. This is similar to calling HloInstruction::Accept on the root of
+ // the computation except this method also visits instructions not reachable
+ // via the root. The root instruction of the computation is visited last, and
+ // the visitor's FinishVisit method is called once upon completion (with the
+ // root instruction as the argument).
+ Status Accept(DfsHloVisitor* visitor) const;
+
+ // Same as Accept() above, but the visitor is given as a function.
+ Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const;
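+
+  // For example (illustrative), visiting every instruction with a lambda:
+  //
+  //   Status s = computation->Accept(
+  //       [](HloInstruction* instruction) { return Status::OK(); });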
+
+ private:
+ explicit HloComputation(
+ const string& name, int parameter_count,
+ std::vector<std::unique_ptr<HloInstruction>>* instructions,
+ HloInstruction* root_instruction);
+
+ // Internal helper for adding instructions.
+ HloInstruction* AddInstructionInternal(
+ std::unique_ptr<HloInstruction> instruction);
+
+ // Remove an instruction from the computation if found. The instruction must
+ // have no users. Instruction is deallocated with this call.
+ // Return whether instruction was found and removed.
+ bool RemoveInstructionIfFound(HloInstruction* instruction);
+
+ // Fuses HLOs in instructions_to_fuse into fusion_instruction.
+ //
+ // Pre-condition: fusion_instruction's opcode is kFusion.
+ void FuseInstructionsInto(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_fuse,
+ HloInstruction* fusion_instruction);
+
+ // Internal helper for copying a tuple value. Creates and returns a deep copy
+ // of the given instruction. The given instruction must be tuple-shaped.
+ StatusOr<HloInstruction*> DeepCopyTuple(HloInstruction* instruction);
+
+ const string name_;
+ HloInstruction* root_instruction_;
+
+ // Module containing this computation.
+ HloModule* parent_ = nullptr;
+
+ // Store instructions in std::list as they can be added and removed
+ // arbitrarily and we want a stable iteration order. Keep a map from
+ // instruction pointer to location in the list for fast lookup.
+ using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
+ InstructionList instructions_;
+ std::unordered_map<const HloInstruction*, InstructionList::iterator>
+ instruction_iterators_;
+
+ std::vector<HloInstruction*> param_instructions_;
+
+ // Unique name generator for instruction identifiers. Instruction names should
+ // be unique per computation and this is enforced when instructions are added
+ // to the computation.
+ NameUniquer instruction_name_uniquer_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HloComputation);
+};
+
+class HloComputation::ReachabilityMap {
+ public:
+  // Sets up an empty reachability matrix for the full set of
+  // instructions specified in "all_instructions".
+ explicit ReachabilityMap(const std::list<HloInstruction*>& all_instructions);
+ // Sets entry so that IsReachable(a, b) will return true
+ void SetReachable(const HloInstruction* a, const HloInstruction* b);
+
+ // Sets IsReachable(a_inst, b_inst) as well as IsReachable(a_inst, trans)
+ // for all "trans" s.t. "IsReachable(b_inst, trans)" is true
+ void SetReachableAndTransitiveClosure(const HloInstruction* a_inst,
+ const HloInstruction* b_inst);
+
+ // Returns true if "b" is reachable from "a"
+ bool IsReachable(const HloInstruction* a, const HloInstruction* b) const;
+
+ // Returns true if "b" is reachable from "a" or "a" is reachable from "b"
+ bool IsConnected(const HloInstruction* a, const HloInstruction* b) const;
+
+ private:
+ friend class HloComputation;
+
+ // dense id assignment from HloInstruction* to number
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> ids_;
+ // matrix_(a,b) is true iff b is reachable from a
+ tensorflow::core::Bitmap matrix_;
+};
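+
+// Illustrative use of the reachability map (a sketch; 'a' and 'b' are
+// instructions in 'computation'):
+//
+//   auto reachability = computation->ComputeTransitiveOperands();
+//   if (reachability->IsReachable(a, b)) {
+//     // 'b' is a (transitive) operand of 'a'.
+//   }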
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COMPUTATION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
new file mode 100644
index 0000000000..1e0d09b72c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -0,0 +1,311 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+
+#include <set>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+
+namespace {
+
+class HloComputationTest : public HloTestBase {
+ protected:
+ HloComputationTest() {}
+
+ // Create a computation which takes a scalar and returns its negation.
+ std::unique_ptr<HloComputation> CreateNegateComputation() {
+ auto builder = HloComputation::Builder("Negate");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
+ return builder.Build();
+ }
+
+ // Creates a computation which calls map with the given computation.
+ std::unique_ptr<HloComputation> CreateMapComputation(
+ HloComputation* map_computation) {
+ auto builder = HloComputation::Builder("Map");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map_computation));
+ return builder.Build();
+ }
+
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
+ auto negate_computation = CreateNegateComputation();
+ EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
+}
+
+TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
+ // Create computation which calls one other computation.
+ auto negate_computation = CreateNegateComputation();
+ auto map_computation = CreateMapComputation(negate_computation.get());
+ EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
+ EXPECT_EQ(map_computation->MakeEmbeddedComputationsList().front(),
+ negate_computation.get());
+}
+
+TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
+ // Create computations with a diamond-shaped callgraph.
+ auto negate_computation = CreateNegateComputation();
+ auto map1_computation = CreateMapComputation(negate_computation.get());
+ auto map2_computation = CreateMapComputation(negate_computation.get());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32_, "param0"));
+ auto map1 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
+ auto map2 = builder.AddInstruction(
+ HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
+ auto computation = builder.Build();
+
+ auto embedded_computations = computation->MakeEmbeddedComputationsList();
+ EXPECT_EQ(3, embedded_computations.size());
+  // MakeEmbeddedComputationsList returns a post order of the embedded
+  // computations, so the negate computation must come first.
+ EXPECT_EQ(negate_computation.get(), *embedded_computations.begin());
+ EXPECT_MATCH(testing::ListToVec<HloComputation*>(embedded_computations),
+ testing::UnorderedMatcher<HloComputation*>(
+ negate_computation.get(), map1_computation.get(),
+ map2_computation.get()));
+}
+
+TEST_F(HloComputationTest, PostOrderSingleton) {
+  // Test MakeInstructionPostOrder for a computation with one instruction.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto computation = builder.Build();
+
+ EXPECT_EQ(computation->MakeInstructionPostOrder().front(), constant);
+}
+
+TEST_F(HloComputationTest, PostOrderSimple) {
+  // Test MakeInstructionPostOrder for a computation with a chain of
+  // instructions.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
+ auto negate2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
+ auto computation = builder.Build();
+
+ EXPECT_MATCH(
+ testing::ListToVec<HloInstruction*>(
+ computation->MakeInstructionPostOrder()),
+ testing::OrderedMatcher<HloInstruction*>(constant, negate1, negate2));
+}
+
+TEST_F(HloComputationTest, PostOrderTrace) {
+  // Test MakeInstructionPostOrder for a computation with a trace instruction.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
+ auto trace =
+ builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1));
+ auto negate2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
+ auto computation = builder.Build();
+
+ // Trace instructions should be at the end of the sort.
+ EXPECT_MATCH(testing::ListToVec<HloInstruction*>(
+ computation->MakeInstructionPostOrder()),
+ testing::OrderedMatcher<HloInstruction*>(constant, negate1,
+ negate2, trace));
+}
+
+TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
+  // Test MakeInstructionPostOrder for a computation with multiple instructions
+  // which are not connected.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant4 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto computation = builder.Build();
+
+ EXPECT_MATCH(testing::ListToVec<HloInstruction*>(
+ computation->MakeInstructionPostOrder()),
+ testing::UnorderedMatcher<HloInstruction*>(
+ constant1, constant2, constant3, constant4));
+}
+
+TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
+  // Test MakeInstructionPostOrder for a computation with multiple root
+  // instructions (expressions with no users).
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32_, HloOpcode::kAdd, constant1, constant2));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32_, HloOpcode::kAdd, constant2, constant3));
+ auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32_, HloOpcode::kAdd, constant1, constant3));
+ auto computation = builder.Build();
+
+ auto post_order = computation->MakeInstructionPostOrder();
+ EXPECT_EQ(6, post_order.size());
+ EXPECT_MATCH(testing::ListToVec<HloInstruction*>(post_order),
+ testing::UnorderedMatcher<HloInstruction*>(
+ constant1, constant2, constant3, add1, add2, add3));
+}
+
+TEST_F(HloComputationTest, VisitWithMultipleRoots) {
+ // Test that Accept visits all instructions in the computation even if the
+ // computation has multiple roots (dead code).
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ // Add three disconnected add expressions.
+ builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ constant1, constant2));
+ builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ constant2, constant3));
+ builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ constant1, constant3));
+ auto computation = builder.Build();
+
+ // Visitor which keeps track of which instructions have been visited.
+ class TestVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit TestVisitor(HloComputation* computation)
+ : computation_(computation) {}
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ EXPECT_EQ(0, visited_set_.count(hlo_instruction));
+ visited_set_.insert(hlo_instruction);
+ last_visited_ = hlo_instruction;
+ return Status::OK();
+ }
+
+ Status FinishVisit(HloInstruction* root) override {
+ EXPECT_EQ(computation_->root_instruction(), root);
+ ++finish_visit_calls_;
+ return Status::OK();
+ }
+
+ HloComputation* computation_;
+ std::set<HloInstruction*> visited_set_;
+ int64 finish_visit_calls_ = 0;
+ HloInstruction* last_visited_ = nullptr;
+ };
+
+ TestVisitor visitor(computation.get());
+ EXPECT_IS_OK(computation->Accept(&visitor));
+
+ EXPECT_EQ(6, visitor.visited_set_.size());
+ EXPECT_EQ(1, visitor.finish_visit_calls_);
+ EXPECT_EQ(computation->root_instruction(), visitor.last_visited_);
+}
+
+TEST_F(HloComputationTest, DeepCopyArray) {
+ // Test that DeepCopyInstruction properly copies an array.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+ auto computation = builder.Build();
+
+ auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
+
+ EXPECT_EQ(HloOpcode::kCopy, copy->opcode());
+ EXPECT_EQ(constant, copy->operand(0));
+}
+
+TEST_F(HloComputationTest, DeepCopyTuple) {
+ // Test that DeepCopyInstruction properly copies a tuple.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+
+ auto computation = builder.Build();
+
+ auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();
+
+ EXPECT_EQ(HloOpcode::kTuple, tuple_copy->opcode());
+ EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(0)->opcode());
+ const HloInstruction* gte0 = tuple_copy->operand(0)->operand(0);
+ EXPECT_EQ(HloOpcode::kGetTupleElement, gte0->opcode());
+ EXPECT_EQ(0, gte0->tuple_index());
+ EXPECT_EQ(tuple, gte0->operand(0));
+
+ EXPECT_EQ(HloOpcode::kCopy, tuple_copy->operand(1)->opcode());
+ const HloInstruction* gte1 = tuple_copy->operand(1)->operand(0);
+ EXPECT_EQ(HloOpcode::kGetTupleElement, gte1->opcode());
+ EXPECT_EQ(1, gte1->tuple_index());
+ EXPECT_EQ(tuple, gte1->operand(0));
+}
+
+TEST_F(HloComputationTest, CycleDetection) {
+ // Test whether the visitor can detect cycles in the graph.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto negate = builder.AddInstruction(
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
+ auto computation = builder.Build();
+
+ // Add a control dependency to create a cycle.
+ ASSERT_IS_OK(computation->AddControlDependency(add, negate));
+
+ const auto visitor = [](HloInstruction* instruction) { return Status::OK(); };
+ auto visit_status = computation->Accept(visitor);
+ ASSERT_FALSE(visit_status.ok());
+ ASSERT_MATCH(visit_status.error_message(),
+ testing::ContainsRegex("cycle is detecte"));
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
new file mode 100644
index 0000000000..1b2a955f39
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -0,0 +1,350 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+
+#include <cmath>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) {
+ const auto& shape = hlo_instruction->shape();
+ // For element-wise operations, the number of computations is the same as the
+ // number of elements in the output shape.
+ auto computation_count = ShapeUtil::ElementsIn(shape);
+ auto opcode = hlo_instruction->opcode();
+ // We treat the two opcodes (kExp, kPower) as transcendental operations.
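+  // For example, an elementwise kAdd producing f32[10,20] adds 200 to the flop
+  // count, while a kExp over the same shape adds 200 to the transcendental
+  // count instead.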
+ if (opcode == HloOpcode::kExp || opcode == HloOpcode::kPower) {
+ transcendental_count_ += computation_count;
+ } else {
+ // Note: transcendental operations are considered a separate category from
+ // FLOPs.
+ hlo_to_flop_count_[hlo_instruction] = computation_count;
+ flop_count_ += computation_count;
+ }
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo,
+ HloOpcode opcode,
+ HloInstruction* operand) {
+ return HandleElementwiseOp(hlo);
+}
+
+Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo,
+ HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseOp(hlo);
+}
+
+Status HloCostAnalysis::HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs) {
+ return HandleElementwiseOp(compare);
+}
+
+Status HloCostAnalysis::HandleClamp(HloInstruction* clamp,
+ HloInstruction* min_instruction,
+ HloInstruction* arg_instruction,
+ HloInstruction* max_instruction) {
+ return HandleElementwiseOp(clamp);
+}
+
+Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleConstant(HloInstruction* constant,
+ const Literal& literal) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleSelect(HloInstruction* select,
+ HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleReverse(HloInstruction* reverse,
+ HloInstruction* operand_instruction) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleSlice(HloInstruction* slice,
+ HloInstruction* operand_instruction) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleDynamicSlice(
+ HloInstruction* slice,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleDynamicUpdateSlice(
+ HloInstruction* dynamic_update, HloInstruction* operand,
+ HloInstruction* update, HloInstruction* start_indices) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) {
+ flop_count_ += ShapeUtil::ElementsIn(operand->shape());
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleCopy(HloInstruction* copy,
+ HloInstruction* operand) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleDot(HloInstruction* dot,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction) {
+ // We count an FMA operation as 2 floating point operations.
+ // Multiplying the sizes of lhs, rhs, and result produces the square of the
+ // number of FMAs during the computation.
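+  // For example, a dot of f32[10,5] and f32[5,30] produces f32[10,30];
+  // 50 * 150 * 300 = 2,250,000, whose square root gives 1500 FMAs, i.e. 3000
+  // flops (see the MatrixMultiply test).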
+ auto fma_count = std::sqrt(
+ static_cast<double>(ShapeUtil::ElementsIn(lhs_instruction->shape())) *
+ ShapeUtil::ElementsIn(rhs_instruction->shape()) *
+ ShapeUtil::ElementsIn(dot->shape()));
+ flop_count_ += 2 * fma_count;
+ hlo_to_flop_count_[dot] = 2 * fma_count;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleInfeed(HloInstruction* infeed) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleMap(
+ HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
+ // Compute the cost of the user function.
+ HloInstruction* function_instruction = function->root_instruction();
+ HloCostAnalysis visitor;
+ TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+
+ // Compute the cost of all elements for this Map operation.
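+  // For example, mapping a user function consisting of one add and one exp
+  // over f32[10] contributes 10 flops and 10 transcendental operations.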
+ auto element_count = ShapeUtil::ElementsIn(map->shape());
+ flop_count_ += element_count * visitor.flop_count();
+ transcendental_count_ += element_count * visitor.transcendental_count();
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleReduce(
+ HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
+ // Compute the cost of the user function.
+ HloInstruction* function_instruction = function->root_instruction();
+ HloCostAnalysis visitor;
+ TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+
+ // Compute the cost of all elements for this Reduce operation.
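+  // For example, reducing f32[10,20] along one dimension down to f32[10]
+  // applies the reduction function 200 - 10 = 190 times.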
+ auto reduction_count = ShapeUtil::ElementsIn(arg->shape()) -
+ ShapeUtil::ElementsIn(reduce->shape());
+ flop_count_ += reduction_count * visitor.flop_count();
+ transcendental_count_ += reduction_count * visitor.transcendental_count();
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* operand,
+ const Window& window,
+ HloComputation* function) {
+ // Compute the cost of the user function.
+ HloInstruction* function_instruction = function->root_instruction();
+ HloCostAnalysis visitor;
+ TF_RETURN_IF_ERROR(function_instruction->Accept(&visitor));
+
+  // Compute the cost of all elements for this ReduceWindow operation. For each
+  // output element, the user computation is applied (window_size - 1) times.
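+  // For example, reducing f32[10,20] with a [4,5] window and [4,5] strides
+  // yields a [2,4] output; each of the 8 output elements applies the user
+  // computation 4 * 5 - 1 = 19 times.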
+ auto output_size = ShapeUtil::ElementsIn(reduce_window->shape());
+ int64 window_size = 1;
+ for (const auto& dimension : window.dimensions()) {
+ window_size *= dimension.size();
+ }
+ flop_count_ += output_size * (window_size - 1) * visitor.flop_count();
+ transcendental_count_ +=
+ output_size * (window_size - 1) * visitor.transcendental_count();
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) {
+ // Compute the cost of the select and scatter function.
+ HloInstruction* select = instruction->select()->root_instruction();
+ HloCostAnalysis select_visitor;
+ TF_RETURN_IF_ERROR(select->Accept(&select_visitor));
+ HloInstruction* scatter = instruction->scatter()->root_instruction();
+ HloCostAnalysis scatter_visitor;
+ TF_RETURN_IF_ERROR(scatter->Accept(&scatter_visitor));
+
+  // Compute the cost of all elements for this operation. For each scatter
+  // source element, the select computation is applied (window_size - 1) times
+  // and the scatter computation is applied once.
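+  // For example, with a [2,4] source and a [4,5] window, each of the 8 source
+  // elements incurs 19 applications of the select computation plus one
+  // application of the scatter computation.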
+ const auto source = instruction->operand(1);
+ const auto source_element_count = ShapeUtil::ElementsIn(source->shape());
+ int64 window_size = 1;
+ for (const auto& dimension : instruction->window().dimensions()) {
+ window_size *= dimension.size();
+ }
+ flop_count_ +=
+ source_element_count * ((window_size - 1) * select_visitor.flop_count() +
+ scatter_visitor.flop_count());
+ transcendental_count_ +=
+ source_element_count *
+ ((window_size - 1) * select_visitor.transcendental_count() +
+ scatter_visitor.transcendental_count());
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleBroadcast(HloInstruction* broadcast) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandlePad(HloInstruction* pad) { return Status::OK(); }
+
+Status HloCostAnalysis::HandleSend(HloInstruction* send) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleRecv(HloInstruction* recv) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) {
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution,
+ HloInstruction* lhs_instruction,
+ HloInstruction* rhs_instruction,
+ const Window& window) {
+ const auto& dnums = convolution->convolution_dimension_numbers();
+ const int64 output_features =
+ convolution->shape().dimensions(dnums.feature_dimension());
+
+ // For each output element, we do one fma per element in the
+ // kernel at some given output feature index.
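+  // For example, a [1,1,3,3] kernel with a single output feature costs 9 FMAs
+  // (18 flops) per output element.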
+ const int64 fmas_per_output_element =
+ ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features;
+ const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape());
+ const double hlo_flop_count = static_cast<double>(output_elements) *
+ fmas_per_output_element * kFmaFlops;
+ flop_count_ += hlo_flop_count;
+ hlo_to_flop_count_[convolution] = hlo_flop_count;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) {
+ // We assume 2 replicas, so that each output element is the sum of two input
+ // elements.
+ //
+ // TODO(b/33004697): Compute correct cost here, taking the actual number of
+ // replicas into account.
+ const double hlo_flop_count = ShapeUtil::ElementsIn(crs->shape());
+ flop_count_ += hlo_flop_count;
+ hlo_to_flop_count_[crs] = hlo_flop_count;
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleRng(HloInstruction* random,
+ RandomDistribution distribution) {
+  // TODO(b/26346211): Implement better estimates for the RNG cost, since the
+  // cost changes with the implementation and the distribution. For now, assume
+  // the cost of each generated element is the same as one transcendental
+  // operation.
+ transcendental_count_ += ShapeUtil::ElementsIn(random->shape());
+ return Status::OK();
+}
+
+Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) {
+  // The fusion instruction itself does not contribute to the computation cost;
+  // the cost comes from visiting its fused expression root.
+ return fusion->fused_expression_root()->Accept(this);
+}
+
+Status HloCostAnalysis::HandleCall(
+ HloInstruction* call, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) {
+ return Unimplemented("call");
+}
+
+Status HloCostAnalysis::HandleCustomCall(
+ HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) {
+ return Unimplemented("custom-call");
+}
+
+Status HloCostAnalysis::HandleSort(HloInstruction* sort,
+ HloInstruction* operand_instruction) {
+  // The cost of sort is implementation dependent, so it cannot be determined at
+  // the HLO level. One option is to assume a comparison-based N*log(N) sort.
+ // TODO(b/26346211): Implement the cost model for sort.
+ return Unimplemented("HandleSort");
+}
+
+Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while,
+ HloInstruction* init,
+ HloComputation* condition,
+ HloComputation* body) {
+  // Since the number of iterations of the while node is not statically
+  // determined, we cannot analyze its computation cost.
+ // TODO(b/26346211): Add cost analysis for while node.
+ return Unimplemented("HandleWhile");
+}
+
+Status HloCostAnalysis::FinishVisit(HloInstruction* root) {
+ return Status::OK();
+}
+
+double HloCostAnalysis::hlo_to_flop_count(const HloInstruction& hlo) const {
+ auto it = hlo_to_flop_count_.find(&hlo);
+ return it == hlo_to_flop_count_.end() ? 0.0 : it->second;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
new file mode 100644
index 0000000000..8bed07c07e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// HloCostAnalysis traverses an HLO graph and calculates the amount of
+// computation required for the graph. Each HLO instruction handler provides
+// the computation cost of that instruction, and the values are accumulated
+// over the traversal of the entire graph. Normal floating point operations are
+// counted separately from transcendental operations.
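+//
+// Example usage (mirroring hlo_cost_analysis_test.cc):
+//
+//   HloCostAnalysis analysis;
+//   TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(&analysis));
+//   double flops = analysis.flop_count();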
+class HloCostAnalysis : public DfsHloVisitor {
+ public:
+ HloCostAnalysis() = default;
+ ~HloCostAnalysis() override = default;
+
+ Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* operand) override;
+ Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs) override;
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+ Status HandleCompare(HloInstruction* compare, HloOpcode opcode,
+ HloInstruction* lhs, HloInstruction* rhs) override;
+ Status HandleClamp(HloInstruction* clamp, HloInstruction* min,
+ HloInstruction* arg, HloInstruction* max) override;
+ Status HandleConcatenate(
+ HloInstruction* concatenate,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleSend(HloInstruction* send) override;
+ Status HandleRecv(HloInstruction* recv) override;
+ Status HandleConvert(HloInstruction* convert,
+ HloInstruction* operand) override;
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
+ HloInstruction* rhs) override;
+ Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs,
+ HloInstruction* rhs, const Window& window) override;
+ Status HandleCrossReplicaSum(HloInstruction* crs) override;
+ Status HandleInfeed(HloInstruction* infeed) override;
+ Status HandleRng(HloInstruction* random,
+ RandomDistribution distribution) override;
+ Status HandleReverse(HloInstruction* reverse,
+ HloInstruction* operand) override;
+ Status HandleSort(HloInstruction* sort, HloInstruction* operand) override;
+ Status HandleParameter(HloInstruction* parameter) override;
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function_handle) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleCall(HloInstruction* call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) override;
+ Status HandleCustomCall(HloInstruction* custom_call,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) override;
+ Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override;
+ Status HandleDynamicSlice(
+ HloInstruction* slice,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice,
+ HloInstruction* operand,
+ HloInstruction* update,
+ HloInstruction* start_indices) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override;
+ Status HandleReduceWindow(HloInstruction* reduce_window,
+ HloInstruction* operand, const Window& window,
+ HloComputation* function) override;
+ Status HandleSelectAndScatter(HloInstruction* instruction) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleBroadcast(HloInstruction* broadcast) override;
+ Status HandlePad(HloInstruction* pad) override;
+ Status HandleReshape(HloInstruction* reshape) override;
+ Status HandleTranspose(HloInstruction* transpose) override;
+ Status HandleWhile(HloInstruction* xla_while, HloInstruction* init,
+ HloComputation* condition, HloComputation* body) override;
+ Status FinishVisit(HloInstruction* root) override;
+
+  // Returns the amount of computation in the graph.
+ double flop_count() { return flop_count_; }
+ double transcendental_count() { return transcendental_count_; }
+
+ // Resolves the provided HLO instruction to a flop count, or 0 if the HLO was
+ // not found to have a flop count in the analysis.
+ double hlo_to_flop_count(const HloInstruction& hlo) const;
+
+ private:
+ // An FMA counts as two floating point operations in these analyses.
+ static constexpr int64 kFmaFlops = 2;
+
+ // Utility function to handle all element-wise operations.
+ Status HandleElementwiseOp(HloInstruction* hlo_instruction);
+
+ // Mapping from HLO instructions to the flop count we computed for them in the
+ // course of the graph analysis.
+ std::map<const HloInstruction*, double> hlo_to_flop_count_;
+
+ // The number of floating point operations in the graph.
+ double flop_count_ = 0;
+
+ // The number of transcendental operations in the graph.
+ double transcendental_count_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HloCostAnalysis);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_COST_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
new file mode 100644
index 0000000000..776d40ac4d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc
@@ -0,0 +1,337 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+
+#include <limits>
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+// This test suite tests the HLO cost analysis by first building a computation
+// using the client computation builder and then running HloCostAnalysis, which
+// returns the number of floating point and transcendental operations in the
+// graph. We test both individual HLO operations and a mixed graph.
+class HloCostAnalysisTest : public ::testing::Test {
+ protected:
+ HloCostAnalysisTest()
+ : client_(ClientLibrary::LocalClientOrDie()),
+        // Accessing the service instance is required for the unit tests to
+        // enable whitebox accesses to the user computation built from the
+        // client,
+ // as shown in the BuildHloGraph functions below.
+ service_(static_cast<Service*>(ClientLibrary::GetXlaService(
+ static_cast<LocalClient*>(client_)->platform()))),
+ computation_tracker_(service_->computation_tracker()) {
+ // Create a computation for a unary user function: x => exp(x + 0.5)
+ {
+ ComputationBuilder builder(client_, "add_and_exp");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Exp(builder.Add(x, half));
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ add_and_exp_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary user function: (x, y) => x + y
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ add_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a sigmoid function: x => 1 / (1 + exp(-x))
+ {
+ ComputationBuilder builder(client_, "sigmoid");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = builder.ConstantR0<float>(1.0);
+ builder.Div(one, builder.Add(one, builder.Exp(builder.Neg(x))));
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ sigmoid_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary max function: (x, y) => max (x, y)
+ {
+ ComputationBuilder builder(client_, "max");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Max(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ max_ = computation_status.ConsumeValueOrDie();
+ }
+
+ // Create a computation for a binary GT function: (x, y) => x > y
+ {
+ ComputationBuilder builder(client_, "gt");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Gt(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ gt_ = computation_status.ConsumeValueOrDie();
+ }
+ }
+
+ // Build HLO graph from the given builder and return the HLO module.
+ std::unique_ptr<HloModule> BuildHloGraph(ComputationBuilder* builder) {
+ auto computation_status = builder->Build();
+ TF_CHECK_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+ auto user_computation_status =
+ computation_tracker_.Resolve(computation.handle());
+ TF_CHECK_OK(user_computation_status.status());
+ auto user_computation = user_computation_status.ConsumeValueOrDie();
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+ return std::move(
+ computation_tracker_.BuildHloModule(versioned_handle).ValueOrDie());
+ }
+
+ Client* client_;
+ Service* service_;
+ const ComputationTracker& computation_tracker_;
+
+ // User computations used for higher order operations (e.g., Map, Reduce).
+ Computation add_;
+ Computation add_and_exp_;
+ Computation sigmoid_;
+ Computation max_;
+ Computation gt_;
+};
+
+TEST_F(HloCostAnalysisTest, MatrixMultiply) {
+ ComputationBuilder builder(client_, "matrix_multiply");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "lhs");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 30}), "rhs");
+ auto result = builder.Dot(lhs, rhs);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Check the number of computations returned from the analysis (1500 FMAs).
+ EXPECT_EQ(analysis.flop_count(), 2 * 10 * 30 * 5);
+}
+
+TEST_F(HloCostAnalysisTest, Map) {
+ ComputationBuilder builder(client_, "map");
+ auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in");
+ auto result = builder.Map({input}, add_and_exp_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // add contributes 10 flops and exp contributes 10 transcendental ops.
+ EXPECT_EQ(analysis.flop_count(), 10);
+ EXPECT_EQ(analysis.transcendental_count(), 10);
+}
+
+TEST_F(HloCostAnalysisTest, Convolution) {
+ ComputationBuilder builder(client_, "convolution");
+ auto input = builder.Parameter(
+ 0, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/10,
+ /*x_dim=*/20}),
+ "input");
+ auto kernel = builder.Parameter(
+ 1, ShapeUtil::MakeShape(F32, {/*p_dim=*/1, /*z_dim=*/1, /*y_dim=*/3,
+ /*x_dim=*/3}),
+ "kernel");
+ auto result = builder.Conv(input, kernel, {1, 1}, Padding::kValid);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // The output shape is [1x1x8x18]; each output element requires (3x3) FMAs,
+  // and each FMA counts as 2 flops.
+ EXPECT_EQ(analysis.flop_count(), 8 * 18 * 2 * 3 * 3);
+}
+
+TEST_F(HloCostAnalysisTest, Reduce) {
+ ComputationBuilder builder(client_, "reduce");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto result =
+ builder.Reduce(input, builder.ConstantR0<float>(0.0f), add_, {1});
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // Subtracting the output size from the input size gives the number of
+ // reduction operations performed.
+ EXPECT_EQ(analysis.flop_count(), 10 * 20 - 10);
+}
+
+TEST_F(HloCostAnalysisTest, ReduceWindow) {
+ ComputationBuilder builder(client_, "reduce_window");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto result = builder.ReduceWindow(input, builder.ConstantR0<float>(0), add_,
+ {4, 5}, {4, 5}, Padding::kValid);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Each of the [2x4] output elements is generated by reducing [4x5] elements.
+ EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1));
+}
+
+TEST_F(HloCostAnalysisTest, SelectAndScatter) {
+ ComputationBuilder builder(client_, "select_and_scatter");
+ auto operand =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 20}), "input");
+ auto source =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 4}), "source");
+ auto result =
+ builder.SelectAndScatter(operand, gt_, {4, 5}, {4, 5}, Padding::kValid,
+ source, builder.ConstantR0<float>(0), add_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+  // Each of the [2x4] source elements computes its destination by reducing
+  // [4x5] elements, followed by the scatter computation.
+ EXPECT_EQ(analysis.flop_count(), 2 * 4 * (4 * 5 - 1 + 1));
+}
+
+TEST_F(HloCostAnalysisTest, Broadcast) {
+ ComputationBuilder b(client_, "broadcast");
+ b.Broadcast(b.ConstantR0<float>(42), {10, 7});
+ auto hlo_module = BuildHloGraph(&b);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+ EXPECT_EQ(analysis.flop_count(), 0);
+}
+
+// Calculates the computation cost of a graph with more than one HLO node.
+TEST_F(HloCostAnalysisTest, FullyConnectedForward) {
+ ComputationBuilder builder(client_, "fully_connected_forward");
+ auto input =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {10, 5}), "input");
+ auto weight =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {5, 20}), "weight");
+ auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias");
+ // sigmoid(input * weight + bias)
+ auto result = builder.Map(
+ {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_);
+
+ // Run HLO cost analysis.
+ auto hlo_module = BuildHloGraph(&builder);
+ HloCostAnalysis analysis;
+ ASSERT_IS_OK(
+ hlo_module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ // 1000 FMAs from matrix multiplication, 200 flops from bias addition,
+ // 600 flops from sigmoid, and 200 transcendental ops from sigmoid.
+ EXPECT_EQ(analysis.flop_count(), 2 * 1000 + 200 + 3 * 200);
+ EXPECT_EQ(analysis.transcendental_count(), 200);
+}
+
+TEST_F(HloCostAnalysisTest, MatmulAndConvolutionCanBeTheSameComputation) {
+ HloCostAnalysis conv_analysis;
+ {
+ ComputationBuilder builder(client_, "conv_looking_matmul");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "input");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64, 1, 1}),
+ "weights");
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kSame);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &conv_analysis));
+ }
+
+ HloCostAnalysis matmul_analysis;
+ {
+ ComputationBuilder builder(client_, "matmul");
+ auto lhs =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {64, 64}), "input");
+ auto rhs =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {64, 64}), "weights");
+ builder.Dot(lhs, rhs);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &matmul_analysis));
+ }
+
+ EXPECT_EQ(conv_analysis.flop_count(), matmul_analysis.flop_count());
+}
+
+// Note that we still expect that any given operation won't overflow 2^64 FLOPs,
+// just that the sum total may.
+TEST_F(HloCostAnalysisTest, TotalOverflowsInt64) {
+ HloCostAnalysis matmul_analysis;
+ {
+ ComputationBuilder builder(client_, "matmul");
+ auto lhs = builder.Parameter(0, ShapeUtil::MakeShape(F32, {1, 1LL << 62}),
+ "input");
+ auto rhs = builder.Parameter(1, ShapeUtil::MakeShape(F32, {1LL << 62, 1}),
+ "weights");
+ auto a = builder.Dot(lhs, rhs);
+ auto b = builder.Dot(a, lhs);
+ builder.Dot(b, rhs);
+ auto hlo_module = BuildHloGraph(&builder);
+ ASSERT_IS_OK(hlo_module->entry_computation()->root_instruction()->Accept(
+ &matmul_analysis));
+ }
+
+ LOG(INFO) << matmul_analysis.flop_count();
+ EXPECT_GT(matmul_analysis.flop_count(), std::numeric_limits<int64>::max());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
new file mode 100644
index 0000000000..7c28ff9da1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -0,0 +1,134 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+namespace {
+
+// Find and combine identical constants. Constants are identical if they have
+// the same type and value.
+bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) {
+ bool changed = false;
+
+  // Map from the ShortDebugString of the layoutless shape of a constant to the
+  // set of constant instructions with that shape. The layoutless shape is used
+  // to bin possible common constants together and reduce the number of constant
+  // comparisons. If we end up with too many constant comparisons, a more
+  // precise binning might have to be used.
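+  // For example, two kConstant instructions that both hold the f32[2,2] value
+  // {{1, 2}, {3, 4}} bin to the same shape string; LiteralUtil::Equal then
+  // confirms they are identical and one is replaced by the other.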
+ std::multimap<string, HloInstruction*> constants;
+
+ auto inst_it = computation->instructions().begin();
+ while (inst_it != computation->instructions().end()) {
+ HloInstruction* instruction = inst_it->get();
+
+ // Advance list iterator before loop body because iterator may be
+ // invalidated due to deletion.
+ ++inst_it;
+
+ if (instruction->opcode() == HloOpcode::kConstant) {
+ Shape shape = instruction->shape();
+ if (!is_layout_sensitive) {
+ LayoutUtil::ClearLayout(&shape);
+ }
+ string shape_string = shape.ShortDebugString();
+
+ // Compare against all constants with the same shape
+ auto range = constants.equal_range(shape_string);
+ HloInstruction* match = nullptr;
+ for (auto it = range.first; it != range.second; ++it) {
+ if (LiteralUtil::Equal(instruction->literal(), it->second->literal())) {
+ match = it->second;
+ break;
+ }
+ }
+ if (match == nullptr) {
+ constants.emplace(shape_string, instruction);
+ } else {
+ // Match found, replace this instruction with the one in the multimap.
+ computation->ReplaceUsesOfInstruction(instruction, match);
+ computation->RemoveInstruction(instruction);
+ changed = true;
+ }
+ }
+ }
+
+ return changed;
+}
+
+} // namespace
+
+StatusOr<bool> HloCSE::Run(HloModule* module) {
+ bool changed = false;
+ for (auto& computation : module->computations()) {
+ changed |= CombineConstants(computation.get(), is_layout_sensitive_);
+
+ std::list<HloInstruction*> post_order =
+ computation->MakeInstructionPostOrder();
+ std::set<HloInstruction*> removed_instructions;
+ for (auto instruction : post_order) {
+      // If the instruction has already been removed by CSE, or has no
+      // operands, skip over it.
+ if (removed_instructions.count(instruction) > 0 ||
+ instruction->operand_count() == 0) {
+ continue;
+ }
+
+      // An instruction can only be equivalent to another if they share the
+      // exact same set of operands. So to find equivalent instructions, it is
+      // enough to search among the users of operand(0) of this instruction.
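+      // For example, if x = exp(c) and y = exp(c) for the same constant c,
+      // then y appears among the users of c (which is operand(0) of x), is
+      // Identical to x, and is folded into x.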
+ const HloInstruction* operand = instruction->operand(0);
+
+ std::vector<HloInstruction*> equivalent_instructions;
+ for (HloInstruction* user : operand->users()) {
+ if (user != instruction && user->Identical(*instruction) &&
+ (!is_layout_sensitive_ ||
+ ShapeUtil::Equal(user->shape(), instruction->shape()))) {
+ equivalent_instructions.push_back(user);
+ }
+ }
+
+ // Replace all equivalent instructions with this instruction.
+ for (HloInstruction* equivalent_instruction : equivalent_instructions) {
+ computation->ReplaceUsesOfInstruction(equivalent_instruction,
+ instruction);
+ computation->RemoveInstruction(equivalent_instruction);
+ removed_instructions.insert(equivalent_instruction);
+ changed = true;
+ }
+ }
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
new file mode 100644
index 0000000000..5b8b82462a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// A pass which performs common-subexpression elimination. Identical constants
+// and identical instructions with the same operands are commoned. The pass
+// iterates over the instructions in topological order, which enables it to
+// find arbitrarily large common expressions.
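+//
+// Example usage (as in hlo_cse_test.cc):
+//
+//   HloCSE cse(/*is_layout_sensitive=*/false);
+//   StatusOr<bool> changed = cse.Run(module);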
+class HloCSE : public HloPass {
+ public:
+  // If is_layout_sensitive is true, then the pass preserves layout during the
+  // transformation. Otherwise, layout is ignored.
+ explicit HloCSE(bool is_layout_sensitive)
+ : HloPass("cse"), is_layout_sensitive_(is_layout_sensitive) {}
+ ~HloCSE() override {}
+
+ // Run CSE on the given module. Returns whether the module was changed (common
+ // subexpressions were found and eliminated).
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ bool is_layout_sensitive_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CSE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
new file mode 100644
index 0000000000..ec8161f55f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -0,0 +1,428 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class HloCseTest : public HloTestBase {
+ protected:
+ HloCseTest() {}
+};
+
+TEST_F(HloCseTest, CombineTwoConstants) {
+ // Test that two identical constants are commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(2, computation->instruction_count());
+ HloInstruction* constant = computation->instructions().begin()->get();
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(constant->literal(), {}));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR0<float>(84.0);
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
+ // Test that two identical constants with different layouts are commoned if
+ // the pass is not layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{0, 1})));
+ auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{1, 0})));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_NE(add->operand(0), add->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(add->operand(0), add->operand(1));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
+ // Test that two identical constants with different layouts are *not* commoned
+ // if the pass is layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{0, 1})));
+ auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{1, 0})));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(constant1, add->operand(0));
+ EXPECT_EQ(constant2, add->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(constant1, add->operand(0));
+ EXPECT_EQ(constant2, add->operand(1));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
+ // Test that constants with the same value but different type are *not*
+ // commoned.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ // Duplicate the float constant to verify something happens.
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(7, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(6, computation->instruction_count());
+}
+
+TEST_F(HloCseTest, NonscalarConstants) {
+ // Test that identical nonscalar constants are merged.
+ auto builder = HloComputation::Builder(TestName());
+ auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ // Create a constant which has the same shape but a different value.
+ auto uncommon_constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
+
+ // Tie the constants together with a tuple. This makes it easier to refer to
+ // the constant instructions via their use.
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
+ {common_constant1, common_constant2, uncommon_constant}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(uncommon_constant, tuple->operand(2));
+ EXPECT_TRUE(tuple->operand(0) == common_constant1 ||
+ tuple->operand(0) == common_constant2);
+}
+
+TEST_F(HloCseTest, IdenticalInstructions) {
+ // Test that three identical instructions are commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto exp3 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(5, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+ EXPECT_NE(tuple->operand(1), tuple->operand(2));
+ EXPECT_NE(tuple->operand(0), tuple->operand(2));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(tuple->operand(1), tuple->operand(2));
+}
+
+TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
+ // Test that two identical instructions with different layouts are *not*
+ // commoned if the pass is layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/true);
+ EXPECT_FALSE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+}
+
+TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
+ // Test that two identical instructions with different layouts are commoned if
+ // the pass is layout insensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+}
+
+TEST_F(HloCseTest, IdenticalExpressions) {
+ // Test that two identical expressions are commoned. Build the following
+ // computation:
+ //
+ // constant = 42.0
+ // negate1 = neg(constant)
+ // exp1 = exp(constant)
+ // add1 = add(negate1, exp1)
+ // negate2 = neg(constant)
+ // exp2 = exp(constant)
+ // add2 = add(negate2, exp2)
+ // tuple = tuple(add1, add2)
+ //
+ // The *1 instructions should be merged with the *2 instructions.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+
+ auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, negate1, exp1));
+
+ auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, negate2, exp2));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(8, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(5, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode());
+}
+
+TEST_F(HloCseTest, DoNotCombineRng) {
+ // Test that two RNG ops are not commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
+ ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
+ {constant1, constant2}));
+ auto rng2 = builder.AddInstruction(HloInstruction::CreateRng(
+ ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
+ {constant1, constant2}));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, rng1, rng2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ uint32 count_before = computation->instruction_count();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ uint32 count_after = computation->instruction_count();
+ EXPECT_EQ(count_before, count_after);
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng);
+ EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng);
+ EXPECT_NE(root->operand(0), root->operand(1));
+}
+
+// TODO(b/28245743): Handle impure functions correctly in CSE.
+TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
+ // Test that two calls to an impure function are not commoned. RNG
+ // is the source of the impurity.
+
+ auto module = MakeUnique<HloModule>(TestName());
+
+ // rng_function is an impure function because it does RNG.
+ HloComputation* rng_function = nullptr;
+ {
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto builder = HloComputation::Builder(TestName() + "_rng_fun");
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+ scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2}));
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "param"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, rng, param));
+ rng_function = module->AddEmbeddedComputation(builder.Build());
+ }
+
+ // Computation calls rng_function twice with the same parameter.
+ HloComputation* computation = nullptr;
+ {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
+ auto rng1 = builder.AddInstruction(
+ HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
+ auto rng2 = builder.AddInstruction(
+ HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, rng1, rng2));
+ computation = module->AddEntryComputation(builder.Build());
+ }
+
+ EXPECT_EQ(4, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap);
+ EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap);
+ EXPECT_NE(root->operand(0), root->operand(1));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
new file mode 100644
index 0000000000..056bbc2473
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+
+#include <memory>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<bool> HloDCE::Run(HloModule* module) {
+ bool changed = false;
+
+ for (auto& computation : module->computations()) {
+ std::unordered_set<HloInstruction*> live_instructions;
+ TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(
+ [&live_instructions](HloInstruction* instruction) {
+ live_instructions.insert(instruction);
+ return Status::OK();
+ }));
+
+ // Remove any dead roots and their dead transitive operands. Collect them
+ // into a separate list first to avoid problems with iterating through the
+    // computation's instructions while simultaneously removing instructions.
+ std::vector<HloInstruction*> dead_roots;
+ for (auto& instruction : computation->instructions()) {
+ if (instruction->user_count() == 0 &&
+ live_instructions.count(instruction.get()) == 0 &&
+ instruction->opcode() != HloOpcode::kParameter) {
+ dead_roots.push_back(instruction.get());
+ }
+ }
+
+ for (HloInstruction* dead_root : dead_roots) {
+ computation->RemoveInstructionAndUnusedOperands(dead_root);
+ changed = true;
+ }
+ }
+
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h
new file mode 100644
index 0000000000..53ba352890
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dce.h
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass which removes all dead instructions from each computation in the
+// module. An instruction is dead if it is not reachable from the root. This
+// pass does not remove dead parameter instructions as parameter instructions
+// cannot be deleted, nor does the pass remove dead computations.
+class HloDCE : public HloPass {
+ public:
+ HloDCE() : HloPass("dce") {}
+ ~HloDCE() override {}
+
+ // Run the pass on the given module. Returns whether the module was changed
+ // (instructions were removed).
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_DCE_H_
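For orientation, a condensed usage sketch of the pass declared above, in the same spirit as the tests that follow; the `module` variable is assumed to be an already-built module and is not part of the original file.

// Minimal sketch, assuming `module` is a std::unique_ptr<xla::HloModule> that
// already holds an entry computation.
xla::HloDCE dce;
xla::StatusOr<bool> changed = dce.Run(module.get());
if (changed.ok() && changed.ValueOrDie()) {
  // At least one dead, non-parameter instruction was removed.
}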
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
new file mode 100644
index 0000000000..dcd9e00c56
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class HloDceTest : public HloTestBase {
+ protected:
+ HloDceTest() {}
+};
+
+TEST_F(HloDceTest, NoDeadCode) {
+ // Verify that no dead code is removed from a computation with no dead code.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(123.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ HloDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+}
+
+TEST_F(HloDceTest, DeadParameters) {
+  // Verify that dead parameters are not removed, but that uses of the dead
+  // parameters are.
+ auto builder = HloComputation::Builder(TestName());
+ auto live_param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "live_param"));
+ auto dead_param1 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(F32, {}), "dead_param1"));
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 2, ShapeUtil::MakeShape(F32, {}), "dead_param2"));
+
+ // This is a dead negate instruction.
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ dead_param1->shape(), HloOpcode::kNegate, dead_param1));
+
+ // This negate is not dead because it is the root.
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ live_param->shape(), HloOpcode::kNegate, live_param));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(5, computation->instruction_count());
+ EXPECT_EQ(1, dead_param1->user_count());
+
+ HloDCE dce;
+ EXPECT_TRUE(dce.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_EQ(0, dead_param1->user_count());
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
new file mode 100644
index 0000000000..edba55f6cd
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+
+void HloExecutionProfile::AddProfileResult(const HloInstruction* hlo,
+ uint64 cycles_taken) {
+ hlo_to_cycles_taken_[hlo] = cycles_taken;
+}
+
+uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const {
+ auto iter = hlo_to_cycles_taken_.find(&hlo);
+ if (iter == hlo_to_cycles_taken_.end()) {
+ return 0;
+ }
+ return iter->second;
+}
+
+string HloExecutionProfile::ToString(
+ const DeviceDescription& device_description,
+ const HloCostAnalysis& cost_analysis) const {
+ using Item = std::pair<const HloInstruction*, uint64>;
+ std::vector<Item> items(hlo_to_cycles_taken_.begin(),
+ hlo_to_cycles_taken_.end());
+ auto custom_less = [](const Item& lhs, const Item& rhs) {
+ return lhs.second > rhs.second;
+ };
+ std::sort(items.begin(), items.end(), custom_less);
+ string result;
+ const int64 total_cycles = total_cycles_executed();
+ double clock_rate_ghz = device_description.clock_rate_ghz();
+ auto append_item = [&result, total_cycles, clock_rate_ghz](
+ int64 cycles, int64 flops, const string& name) {
+ double nsecs = cycles / clock_rate_ghz;
+ tensorflow::strings::StrAppend(
+ &result,
+ tensorflow::strings::Printf(
+ "%15lld cycles (%6.2f%%) :: %12.1f usec @ f_nom :: %18s :: %s",
+ cycles, cycles / static_cast<double>(total_cycles) * 100,
+ nsecs / 1e3,
+ flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
+ name.c_str()));
+ };
+ tensorflow::strings::StrAppend(
+ &result,
+ tensorflow::strings::Printf("HLO execution profile: (%s @ f_nom)\n\t",
+ tensorflow::strings::HumanReadableElapsedTime(
+ total_cycles / clock_rate_ghz / 1e9)
+ .c_str()));
+ append_item(total_cycles, -1, "[total]");
+ for (const auto& item : items) {
+ tensorflow::strings::StrAppend(&result, "\n\t");
+ auto flops = item.first == nullptr
+ ? -1
+ : cost_analysis.hlo_to_flop_count(*item.first);
+ string display = item.first == nullptr ? "<none>" : item.first->ToString();
+ append_item(item.second, flops, display);
+ }
+ return result;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h
new file mode 100644
index 0000000000..6cc2079813
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h
@@ -0,0 +1,71 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_
+
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class HloInstruction;
+
+// Describes how much time each HLO operation took.
+//
+// Each HloComputation takes a certain number of cycles. This class helps break
+// down how much time each HLO took.
+class HloExecutionProfile {
+ public:
+ using DeviceDescription = perftools::gputools::DeviceDescription;
+
+ // Record how many cycles this HLO took to execute.
+ void AddProfileResult(const HloInstruction* hlo, uint64 cycles_taken);
+
+ // Returns how many cycles this HLO took to execute. Profiling information
+  // may not be available for some instructions, in which case zero is
+  // returned.
+ uint64 GetProfileResult(const HloInstruction& hlo) const;
+
+ // Return the number of cycles this computation took to execute.
+ uint64 total_cycles_executed() const { return total_cycles_executed_; }
+
+ // Record how many cycles the entire computation took to execute.
+ void set_total_cycles_executed(uint64 total_cycles_executed) {
+ total_cycles_executed_ = total_cycles_executed;
+ }
+
+ // Returns a version of the execution profile suitable for performance
+ // debugging; e.g. emits cycle counts, execution time at the nominal device
+ // frequency, and the effective throughput given the provided cost_analysis
+ // for the operations.
+ string ToString(const DeviceDescription& device_description,
+ const HloCostAnalysis& cost_analysis) const;
+
+ private:
+ // Contains a mapping from HLO to the number of cycles it took to execute it.
+ std::unordered_map<const HloInstruction*, uint64> hlo_to_cycles_taken_;
+
+  // If nonzero, contains the total number of cycles this computation took to
+ // execute.
+ uint64 total_cycles_executed_ = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EXECUTION_PROFILE_H_
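To illustrate the accessors above, a small hedged sketch of recording and querying per-instruction timings; the function name, the `hlo_*` pointers, and the cycle counts are all invented for the example.

#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"

// Sketch only: records two per-HLO cycle counts and queries them back.
void RecordAndQueryProfile(const xla::HloInstruction* hlo_a,
                           const xla::HloInstruction* hlo_b,
                           const xla::HloInstruction* hlo_c) {
  xla::HloExecutionProfile profile;
  profile.AddProfileResult(hlo_a, /*cycles_taken=*/1200);
  profile.AddProfileResult(hlo_b, /*cycles_taken=*/300);
  profile.set_total_cycles_executed(1500);

  auto a_cycles = profile.GetProfileResult(*hlo_a);  // 1200
  auto c_cycles = profile.GetProfileResult(*hlo_c);  // 0: nothing was recorded
  (void)a_cycles;
  (void)c_cycles;
}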
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
new file mode 100644
index 0000000000..4865a8fb45
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -0,0 +1,507 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+using ::tensorflow::Env;
+using ::tensorflow::WriteStringToFile;
+using ::tensorflow::io::JoinPath;
+using ::tensorflow::strings::Appendf;
+using ::tensorflow::strings::Printf;
+using ::tensorflow::strings::StrAppend;
+using ::tensorflow::strings::StrCat;
+using ::tensorflow::str_util::Join;
+
+namespace xla {
+namespace hlo_graph_dumper {
+namespace {
+
+// Returns the dot graph identifier for the given instruction.
+string InstructionId(const HloInstruction* instruction) {
+ return Printf("%lld", reinterpret_cast<uint64>(instruction));
+}
+
+// Returns the dot graph identifier for the given computation.
+string ComputationId(const HloComputation* computation) {
+ return Printf("%lld", reinterpret_cast<uint64>(computation));
+}
+
+// Returns a compact string that represents the convolution dimension numbers.
+string ConvolutionDimensionNumbersToString(
+ const ConvolutionDimensionNumbers& dim_numbers) {
+ return Printf("B@%lld,Z@%lld,KIZ@%lld,KOZ@%lld",
+ dim_numbers.batch_dimension(), dim_numbers.feature_dimension(),
+ dim_numbers.kernel_input_feature_dimension(),
+ dim_numbers.kernel_output_feature_dimension());
+}
+
+// Returns a compact string that represents the non-trivial fields in the window
+// description. If there are no non-trivial fields, the empty string is
+// returned.
+string WindowToString(const Window& window) {
+ bool display_padding = false;
+ bool display_window_dilation = false;
+ bool display_base_dilation = false;
+ bool display_stride = false;
+ for (const WindowDimension& dimension : window.dimensions()) {
+ display_padding |=
+ dimension.padding_low() != 0 || dimension.padding_high() != 0;
+ display_window_dilation |= dimension.window_dilation() != 1;
+ display_base_dilation |= dimension.base_dilation() != 1;
+ display_stride |= dimension.stride() != 1;
+ }
+ std::vector<string> pieces = {};
+ if (display_padding) {
+ pieces.push_back("\\n");
+ pieces.push_back("padding=[");
+ for (const WindowDimension& dimension : window.dimensions()) {
+ pieces.push_back(StrCat("(", dimension.padding_low(), ",",
+ dimension.padding_high(), ")"));
+ pieces.push_back(", ");
+ }
+ pieces.pop_back();
+ pieces.push_back("]");
+ }
+ // Make a convenient lambda that adds a simple int64 field in each
+ // WindowDimension.
+ auto add_field = [&pieces, &window](
+ const string& label,
+ tensorflow::protobuf_int64 (WindowDimension::*member)() const) {
+ pieces.push_back("\\n");
+ pieces.push_back(label + "=[");
+ for (const WindowDimension& dimension : window.dimensions()) {
+ pieces.push_back(StrCat(((&dimension)->*member)()));
+ pieces.push_back(", ");
+ }
+ pieces.pop_back();
+ pieces.push_back("]");
+ };
+ if (display_window_dilation) {
+ add_field("window_dilation", &WindowDimension::window_dilation);
+ }
+ if (display_base_dilation) {
+ add_field("base_dilation", &WindowDimension::base_dilation);
+ }
+ if (display_stride) {
+ add_field("stride", &WindowDimension::stride);
+ }
+ return Join(pieces, "");
+}
+
+// Returns the dot graph edges and nodes for the given instruction sequence.
+// Edges which extend between computations are added to the vector
+// intercomputation_edges. This is necessary because graphviz does not render
+// the graph properly unless these inter-computation edges appear after all
+// subgraph statements.
+string InstructionSequenceGraph(
+ const std::list<std::unique_ptr<HloInstruction>>& instructions,
+ bool show_addresses, bool show_layouts,
+ std::vector<string>* intercomputation_edges,
+ const HloExecutionProfile* hlo_execution_profile) {
+ string graph_body;
+
+ // Create a single "record" node for the parameters. This node is a
+  // partitioned rectangle with one partition per parameter node. This keeps
+ // all the parameter instructions together.
+ std::vector<HloInstruction*> param_instructions;
+ for (auto& instruction : instructions) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ int64 param_number = instruction->parameter_number();
+ if (param_instructions.size() < param_number + 1) {
+ param_instructions.resize(param_number + 1, nullptr);
+ }
+ param_instructions[param_number] = instruction.get();
+ }
+ }
+ string param_node_name;
+ if (!param_instructions.empty()) {
+ std::vector<string> param_ports;
+ param_node_name =
+ StrCat("parameters_", InstructionId(param_instructions[0]));
+ for (auto& param : param_instructions) {
+ string label = StrCat(param->parameter_name(), "\\n",
+ ShapeUtil::HumanString(param->shape()));
+ if (show_addresses) {
+ Appendf(&label, "\\n[%p]", param);
+ }
+ if (show_layouts) {
+ StrAppend(&label, "\\nlayout=\\{",
+ Join(param->shape().layout().minor_to_major(), ","), "\\}");
+ }
+ param_ports.push_back(
+ Printf("<%s> %s", InstructionId(param).c_str(), label.c_str()));
+ }
+ StrAppend(&graph_body, param_node_name,
+ " [shape=record,style=filled,fillcolor=\"lightblue1\",",
+ "label=\"{parameters | {", Join(param_ports, "|"), "}}\"];\n");
+ }
+
+ for (auto& instruction : instructions) {
+ string color = "peachpuff";
+ string shape = "ellipse";
+ string name = HloOpcodeString(instruction->opcode());
+ if (HloOpcode::kFusion == instruction->opcode()) {
+ name += ": " + FusionKindString(instruction->fusion_kind());
+ }
+ if (HloOpcode::kConvolution == instruction->opcode()) {
+ name += ":\\n" + ConvolutionDimensionNumbersToString(
+ instruction->convolution_dimension_numbers()) +
+ WindowToString(instruction->window());
+ }
+ name += "\\n" + instruction->name();
+ std::vector<HloComputation*> called_computations;
+
+ // Pick different colors or shapes for instructions which are particularly
+ // expensive (eg, dot) and those which are unusual in some way or unique
+ // (eg, parameter).
+ switch (instruction->opcode()) {
+ // "Normal" instructions. Mostly cheap and elementwise. No call to
+ // embedded computations. In this case, use default color, shape and
+ // label.
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kConvert:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kEq:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kIndex:
+ case HloOpcode::kLe:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kLogicalOr:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kPad:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSign:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSort:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTuple:
+ case HloOpcode::kUpdate:
+ break;
+
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kTranspose:
+ StrAppend(&name, "\\n", "dims={", Join(instruction->dimensions(), ","),
+ "}");
+ break;
+ case HloOpcode::kGetTupleElement:
+ StrAppend(&name, "\\nindex=", instruction->tuple_index());
+ break;
+ case HloOpcode::kRng:
+ StrAppend(&name, "\\n",
+ RandomDistribution_Name(instruction->random_distribution()));
+ break;
+ case HloOpcode::kConstant:
+ shape = "boxed";
+ color = "palegreen";
+ if (ShapeUtil::IsScalar(instruction->shape())) {
+ StrAppend(&name, "\\n", "value=", LiteralUtil::GetAsString(
+ instruction->literal(), {}));
+ }
+ break;
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCopy:
+ color = "white";
+ break;
+ case HloOpcode::kCall:
+ color = "tomato";
+ break;
+ case HloOpcode::kCustomCall:
+ color = "tomato4";
+ StrAppend(&name, "\\n",
+ "custom_call_target=", instruction->custom_call_target());
+ break;
+ case HloOpcode::kDot:
+ color = "slateblue";
+ break;
+ case HloOpcode::kSend:
+ color = "purple";
+ break;
+ case HloOpcode::kRecv:
+ color = "orange";
+ break;
+ case HloOpcode::kMap:
+ color = "palevioletred";
+ break;
+ case HloOpcode::kParameter:
+ // A single record node is created for all the parameter nodes with a
+ // port for each parameter instruction. No need to emit anything in this
+ // case.
+ continue;
+ case HloOpcode::kReduce:
+ StrAppend(&name, " dims=", Join(instruction->dimensions(), ","));
+ color = "lightsalmon";
+ break;
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kReduceWindow:
+ color = "lightsalmon";
+ break;
+ case HloOpcode::kTrace:
+ color = "white";
+ break;
+ case HloOpcode::kWhile:
+ color = "forestgreen";
+ break;
+ case HloOpcode::kFusion:
+ color = "gray";
+ break;
+ case HloOpcode::kConvolution:
+ color = "red";
+ break;
+ case HloOpcode::kCrossReplicaSum:
+ color = "turquoise";
+ break;
+ case HloOpcode::kInfeed:
+ color = "blue";
+ break;
+ }
+
+ // Create instruction node with appropriate label, shape, and color.
+ string label =
+ StrCat(name, "\\n", ShapeUtil::HumanString(instruction->shape()));
+ if (show_addresses) {
+ Appendf(&label, "\\n[%p]", instruction.get());
+ }
+ if (show_layouts && LayoutUtil::HasLayout(instruction->shape())) {
+ string layout_string;
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // For tuples, emit the full shape because the layout of a tuple is not
+ // represented in a single Layout field.
+ layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
+ } else {
+ layout_string =
+ Join(instruction->shape().layout().minor_to_major(), ",");
+ }
+ StrAppend(&label, "\\nlayout={", layout_string, "}");
+ }
+ if (hlo_execution_profile != nullptr) {
+ auto hlo_cycles_executed =
+ hlo_execution_profile->GetProfileResult(*instruction);
+ auto total_cycles_executed =
+ hlo_execution_profile->total_cycles_executed();
+ if (hlo_cycles_executed > 0 && total_cycles_executed > 0) {
+ Appendf(&label, "\\n%% of cycles executed=%.2f",
+ (static_cast<double>(hlo_cycles_executed) /
+ static_cast<double>(total_cycles_executed)) *
+ 100);
+ }
+ }
+ Appendf(&graph_body,
+ "%s [label=\"%s\", shape=%s, style=filled, fillcolor=%s];\n",
+ InstructionId(instruction.get()).c_str(), label.c_str(),
+ shape.c_str(), color.c_str());
+
+ // Create edges from the instruction's operands to the instruction.
+ int64 operand_number = 0;
+ for (auto* operand : instruction->operands()) {
+ string src;
+ if (operand->opcode() == HloOpcode::kParameter) {
+ // If operand is a parameter, then select the proper partition (port) in
+ // the unified parameter node.
+ src = param_node_name + ":" + InstructionId(operand);
+ } else {
+ src = InstructionId(operand);
+ }
+ Appendf(&graph_body, "%s -> %s", src.c_str(),
+ InstructionId(instruction.get()).c_str());
+ if (instruction->operand_count() > 1) {
+ Appendf(&graph_body, " [headlabel=\"%lld\",labeldistance=2]",
+ operand_number);
+ }
+ StrAppend(&graph_body, ";\n");
+ ++operand_number;
+ }
+
+ // Fusion nodes are handled specially because they contain nested
+ // expressions.
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ string cluster_name =
+ StrCat("cluster_", InstructionId(instruction.get()));
+ StrAppend(&graph_body, "subgraph ", cluster_name, " {\n");
+ StrAppend(&graph_body,
+ "label=\"fused expression\";\nstyle=filled;\n"
+ "color=lightgrey;\n");
+ StrAppend(&graph_body, InstructionSequenceGraph(
+ instruction->fused_instructions(),
+ show_addresses, show_layouts,
+ intercomputation_edges, hlo_execution_profile),
+ "}\n");
+ string fusion_edge =
+ StrCat(InstructionId(instruction->fused_expression_root()), " -> ",
+ InstructionId(instruction.get()),
+ " [ style = \"dotted\", arrowsize=0.0, ltail=", cluster_name,
+ " ];\n");
+ intercomputation_edges->push_back(fusion_edge);
+ } else {
+      // Add a dashed edge between the instruction and any computations that the
+ // instruction calls.
+ for (auto* computation : instruction->MakeCalledComputationsSet()) {
+ string cluster_name = StrCat("cluster_", ComputationId(computation));
+ string call_edge = Printf(
+ "%s -> %s [ style=dashed; ltail=%s ];\n",
+ InstructionId(computation->root_instruction()).c_str(),
+ InstructionId(instruction.get()).c_str(), cluster_name.c_str());
+ intercomputation_edges->push_back(call_edge);
+ }
+ }
+ }
+ return graph_body;
+}
+
+string ComputationToDotGraph(const HloComputation& computation,
+ const string& label, bool show_addresses,
+ bool show_layouts,
+ const HloExecutionProfile* hlo_execution_profile) {
+ string graph_label = StrCat(label, "\\n", computation.name());
+ if (hlo_execution_profile != nullptr) {
+ auto cycles = hlo_execution_profile->total_cycles_executed();
+ Appendf(&graph_label, "\\ntotal cycles = %lld (%s)", cycles,
+ tensorflow::strings::HumanReadableNum(cycles).c_str());
+ }
+ string graph =
+ Printf("digraph G {\nrankdir=TB;\ncompound=true;\nlabel=\"%s\"\n",
+ graph_label.c_str());
+
+ // Emit embedded computations as subgraph clusters.
+ std::vector<string> intercomputation_edges;
+ for (auto embedded : computation.MakeEmbeddedComputationsList()) {
+ string graph_body = InstructionSequenceGraph(
+ embedded->instructions(), show_addresses, show_layouts,
+ &intercomputation_edges, hlo_execution_profile);
+ Appendf(&graph, "subgraph cluster_%s {\nlabel=\"%s\";\n%s}\n",
+ ComputationId(embedded).c_str(), embedded->name().c_str(),
+ graph_body.c_str());
+ }
+ StrAppend(&graph,
+ InstructionSequenceGraph(computation.instructions(), show_addresses,
+ show_layouts, &intercomputation_edges,
+ hlo_execution_profile));
+
+  // Edges between computations (subgraph clusters) must be emitted last;
+  // graphviz does not render the graph properly unless these inter-computation
+  // edges appear after all subgraph statements.
+ StrAppend(&graph, Join(intercomputation_edges, "\n"), "}\n");
+
+ return graph;
+}
+
+tensorflow::mutex& RendererMutex() {
+ static tensorflow::mutex* mu = new tensorflow::mutex;
+ return *mu;
+}
+
+std::map<int, GraphRendererInterface*>* GraphRenderers() {
+ static auto* graph_renderers = new std::map<int, GraphRendererInterface*>();
+ return graph_renderers;
+}
+
+GraphRendererInterface* GetGraphRenderer() {
+ tensorflow::mutex_lock lock(RendererMutex());
+ auto* graph_renderers = GraphRenderers();
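+  // GraphRenderers() is a std::map keyed by priority in ascending order, so
+  // rbegin() yields the highest-priority registered renderer.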
+ auto it = graph_renderers->rbegin();
+ CHECK(it != graph_renderers->rend()) << "No registered graph dumpers";
+ return it->second;
+}
+
+} // namespace
+
+Registrar::Registrar(GraphRendererInterface* dumper, int priority) {
+ tensorflow::mutex_lock lock(RendererMutex());
+ auto* graph_renderers = GraphRenderers();
+ graph_renderers->emplace(priority, dumper);
+}
+
+namespace {
+
+class FileGraphRenderer : public GraphRendererInterface {
+ public:
+ string RenderGraph(const string& graph) override {
+ static std::atomic<int> output_num(0);
+ legacy_flags::HloGraphDumperFlags* flags =
+ legacy_flags::GetHloGraphDumperFlags();
+ string path = StrCat(flags->xla_hlo_dump_graph_path, "hlo_graph_",
+ output_num++, ".dot");
+ tensorflow::Status status =
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph);
+ if (!status.ok()) {
+ LOG(WARNING) << "Saving HLO graph failed: " << status;
+ }
+ return path;
+ }
+};
+
+XLA_REGISTER_GRAPH_RENDERER(FileGraphRenderer, 0);
+
+} // namespace
+
+string DumpGraph(const HloComputation& computation, const string& label,
+ bool show_addresses, bool show_layouts,
+ const HloExecutionProfile* hlo_execution_profile) {
+ string graph = ComputationToDotGraph(computation, label, show_addresses,
+ show_layouts, hlo_execution_profile);
+
+ string graph_url = GetGraphRenderer()->RenderGraph(graph);
+ LOG(INFO) << "computation " << computation.name() << " [" << label
+ << "]: " << graph_url;
+ return graph_url;
+}
+
+void DumpText(const HloModule& module, const string& label,
+ const string& directory_path) {
+ Env* env = Env::Default();
+ TF_CHECK_OK(env->RecursivelyCreateDir(directory_path));
+ string prefix = StrCat(env->NowMicros());
+ string path = JoinPath(directory_path, StrCat(prefix, "-", label, ".txt"));
+ TF_CHECK_OK(WriteStringToFile(env, path, module.ToString()));
+}
+
+} // namespace hlo_graph_dumper
+} // namespace xla
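To show how the two entry points defined above are meant to be driven, a brief hedged sketch; `computation` and `module` stand for existing XLA objects, and the label and directory path are placeholders.

// Renders the computation through the highest-priority registered renderer and
// returns a description of the output (a file path with FileGraphRenderer).
std::string graph_url = xla::hlo_graph_dumper::DumpGraph(
    *computation, /*label=*/"after-fusion", /*show_addresses=*/false,
    /*show_layouts=*/true);

// Writes module->ToString() to "<directory_path>/<timestamp>-after-fusion.txt".
xla::hlo_graph_dumper::DumpText(*module, /*label=*/"after-fusion",
                                /*directory_path=*/"/tmp/xla_dumps");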
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.h b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
new file mode 100644
index 0000000000..45fd46352f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace hlo_graph_dumper {
+
+// Dumps a graph of the computation to the GraphViz server and returns
+// a description of the rendered graph (e.g., a URL).
+string DumpGraph(const HloComputation& computation, const string& label,
+ bool show_addresses, bool show_layouts,
+ const HloExecutionProfile* hlo_execution_profile = nullptr);
+
+// Dumps the HloModule::ToString() as a file into the provided directory path
+// suffixed with the provided label.
+void DumpText(const HloModule& module, const string& label,
+ const string& directory_path);
+
+// Abstract interface for classes that render DOT graphs.
+class GraphRendererInterface {
+ public:
+ virtual ~GraphRendererInterface() = default;
+
+ // Renders a DOT graph, returning a description of the rendered output
+ // (e.g., a URL)
+ virtual string RenderGraph(const string& graph) = 0;
+};
+
+// Graph renderers may be added using a registration mechanism, e.g.:
+// XLA_REGISTER_GRAPH_RENDERER(AGraphRendererClass, 100)
+// The renderer with the highest numeric priority value is used.
+
+#define XLA_REGISTER_GRAPH_RENDERER(factory, ...) \
+ XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, __COUNTER__, ##__VA_ARGS__)
+
+// Internal implementation details below this point.
+
+// Class that registers a graph renderer. Higher-priority renderers are chosen
+// first.
+class Registrar {
+ public:
+ Registrar(GraphRendererInterface* dumper, int priority);
+};
+
+#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER(factory, ctr, ...) \
+ static ::xla::hlo_graph_dumper::Registrar \
+ XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr)(new factory, \
+ ##__VA_ARGS__)
+
+// __COUNTER__ must go through another macro to be properly expanded
+#define XLA_INTERNAL_REGISTER_GRAPH_RENDERER_NAME(ctr) ___##ctr##__object_
+
+} // namespace hlo_graph_dumper
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_GRAPH_DUMPER_H_
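As a concrete (though hypothetical) instance of the registration mechanism described above, the sketch below registers an in-memory renderer at a priority above the default FileGraphRenderer; `InMemoryGraphRenderer` does not exist in the codebase and is shown only for illustration.

#include <string>

#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"

namespace xla {
namespace hlo_graph_dumper {
namespace {

// Hypothetical renderer that keeps the most recent DOT text in memory instead
// of writing a file; RenderGraph() returns a short description of the result.
class InMemoryGraphRenderer : public GraphRendererInterface {
 public:
  string RenderGraph(const string& graph) override {
    last_graph_ = graph;
    return "in-memory graph (" + std::to_string(graph.size()) + " bytes)";
  }

 private:
  string last_graph_;
};

// Priority 100 outranks the default FileGraphRenderer (priority 0), so
// DumpGraph() would route through this renderer.
XLA_REGISTER_GRAPH_RENDERER(InMemoryGraphRenderer, 100);

}  // namespace
}  // namespace hlo_graph_dumper
}  // namespace xla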
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
new file mode 100644
index 0000000000..7ae0a995af
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -0,0 +1,1921 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+#include <algorithm>
+#include <deque>
+#include <set>
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
+ int64 parameter_number, const Shape& shape, const string& name) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kParameter, shape));
+ instruction->parameter_number_ = parameter_number;
+ instruction->parameter_name_ = name;
+ instruction->name_ = "%" + name;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
+ const string& tag, HloInstruction* operand) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
+ instruction->operands_.push_back(operand);
+ instruction->literal_.reset(new Literal);
+ *instruction->literal_->mutable_u8s() += tag;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
+ std::unique_ptr<Literal> literal) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
+ instruction->literal_ = std::move(literal);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateGetTupleElement(const Shape& shape,
+ HloInstruction* operand, int64 index) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kGetTupleElement, shape));
+ instruction->tuple_index_ = index;
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
+ const Shape& shape, RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<HloInstruction*> parameters) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRng, shape));
+ instruction->distribution_ = distribution;
+ instruction->shape_ = shape;
+ for (HloInstruction* param : parameters) {
+ instruction->AppendOperand(param);
+ }
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
+ const Shape& shape, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ if (opcode == HloOpcode::kCopy) {
+    // It is impossible to copy an opaque shape; we don't know how big it is.
+ CHECK(!ShapeUtil::IsOpaque(shape));
+ }
+ auto instruction = WrapUnique(new HloInstruction(opcode, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
+ const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
+ // Only certain opcodes are supported with CreateUnary: opcodes of unary
+ // instructions with no auxiliary fields.
+ switch (opcode) {
+ case HloOpcode::kAbs:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCeil:
+ case HloOpcode::kCopy:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kNegate:
+ case HloOpcode::kSign:
+ case HloOpcode::kSort:
+ case HloOpcode::kTanh:
+ break;
+ default:
+ LOG(FATAL) << "Invalid unary instruction opcode "
+ << HloOpcodeString(opcode);
+ }
+ return CreateNary(shape, opcode, {operand});
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
+ const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ // Only certain opcodes are supported with CreateBinary: opcodes of binary
+ // instructions with no auxiliary fields.
+ switch (opcode) {
+ case (HloOpcode::kAdd):
+ case (HloOpcode::kDivide):
+ case (HloOpcode::kDot):
+ case (HloOpcode::kEq):
+ case (HloOpcode::kGe):
+ case (HloOpcode::kGt):
+ case (HloOpcode::kLe):
+ case (HloOpcode::kLt):
+ case (HloOpcode::kMaximum):
+ case (HloOpcode::kMinimum):
+ case (HloOpcode::kMultiply):
+ case (HloOpcode::kNe):
+ case (HloOpcode::kPower):
+ case (HloOpcode::kRemainder):
+ case (HloOpcode::kSubtract):
+ case (HloOpcode::kLogicalAnd):
+ case (HloOpcode::kLogicalOr):
+ break;
+ default:
+ LOG(FATAL) << "Invalid binary instruction opcode "
+ << HloOpcodeString(opcode);
+ }
+ return CreateNary(shape, opcode, {lhs, rhs});
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
+ const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
+ HloInstruction* rhs, HloInstruction* ehs) {
+ // Only certain opcodes are supported with CreateTernary: opcodes of ternary
+ // instructions with no auxiliary fields.
+ switch (opcode) {
+ case (HloOpcode::kClamp):
+ case (HloOpcode::kSelect):
+ break;
+ default:
+ LOG(FATAL) << "Invalid ternary instruction opcode "
+ << HloOpcodeString(opcode);
+ }
+ return CreateNary(shape, opcode, {lhs, rhs, ehs});
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
+ const Shape& shape, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ CHECK_EQ(HloOpcode::kTuple, opcode);
+ return CreateNary(shape, opcode, operands);
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) {
+ CHECK(static_operands.empty()) << "static_operands not yet supported";
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kMap, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->to_apply_ = map_computation;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kConvolution, shape));
+ instruction->AppendOperand(lhs);
+ instruction->AppendOperand(rhs);
+ instruction->window_ = MakeUnique<Window>(window);
+ instruction->convolution_dimension_numbers_ =
+ MakeUnique<ConvolutionDimensionNumbers>(dimension_numbers);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateCrossReplicaSum(const Shape& shape,
+ HloInstruction* operand) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
+ const Shape& shape) {
+ return WrapUnique(new HloInstruction(HloOpcode::kInfeed, shape));
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
+ HloInstruction* operand) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
+ const Shape& shape) {
+ return WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReverse, shape));
+ instruction->AppendOperand(operand);
+ instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
+ const Shape& shape, HloComputation* condition, HloComputation* body,
+ HloInstruction* init) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
+ instruction->AppendOperand(init);
+ instruction->condition_ = condition;
+ instruction->body_ = body;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kSlice, shape));
+ instruction->AppendOperand(operand);
+ instruction->slice_starts_.assign(start_indices.begin(), start_indices.end());
+ instruction->slice_limits_.assign(limit_indices.begin(), limit_indices.end());
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
+ const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kDynamicSlice, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(start_indices);
+ instruction->dynamic_slice_sizes_.assign(slice_sizes.begin(),
+ slice_sizes.end());
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateDynamicUpdateSlice(const Shape& shape,
+ HloInstruction* operand,
+ HloInstruction* update,
+ HloInstruction* start_indices) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kDynamicUpdateSlice, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(update);
+ instruction->AppendOperand(start_indices);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kConcatenate, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->dimensions_.push_back(dimension);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
+ const Shape& shape, HloInstruction* operand) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
+ const Shape& shape, HloInstruction* arg, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReduce, shape));
+ instruction->AppendOperand(arg);
+ instruction->AppendOperand(init_value);
+ instruction->dimensions_.assign(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ instruction->to_apply_ = reduce_computation;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
+ const Window& window, HloComputation* reduce_computation) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kReduceWindow, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(init_value);
+ instruction->to_apply_ = reduce_computation;
+ instruction->window_ = MakeUnique<Window>(window);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateSelectAndScatter(
+ const Shape& shape, HloInstruction* operand, HloComputation* select,
+ const Window& window, HloInstruction* source, HloInstruction* init_value,
+ HloComputation* scatter) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kSelectAndScatter, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(source);
+ instruction->AppendOperand(init_value);
+ instruction->select_ = select;
+ instruction->scatter_ = scatter;
+ instruction->window_ = MakeUnique<Window>(window);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kBroadcast, shape));
+ instruction->AppendOperand(operand);
+ instruction->dimensions_.assign(broadcast_dimensions.begin(),
+ broadcast_dimensions.end());
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
+ const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
+ const PaddingConfig& padding_config) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kPad, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(padding_value);
+ instruction->padding_config_ = MakeUnique<PaddingConfig>(padding_config);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
+ const Shape& shape, HloInstruction* operand) {
+ CHECK_EQ(ShapeUtil::ElementsIn(shape),
+ ShapeUtil::ElementsIn(operand->shape()));
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
+ instruction->AppendOperand(operand);
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions) {
+ CHECK_EQ(shape.dimensions().size(), dimensions.size());
+ CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
+ CHECK(std::equal(operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(dimensions, shape.dimensions()).begin()));
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kTranspose, shape));
+ instruction->AppendOperand(operand);
+ instruction->dimensions_.assign(dimensions.begin(), dimensions.end());
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
+ const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
+ instruction->fusion_kind_ = fusion_kind;
+ instruction->CloneAndFuseInternal(fused_root);
+ instruction->CheckFusionInstruction();
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction>
+HloInstruction::CreateFusionForBackwardConvolution(
+ const Shape& shape, FusionKind fusion_kind, const Window& window,
+ const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) {
+ std::unique_ptr<HloInstruction> fusion =
+ CreateFusion(shape, fusion_kind, fused_root);
+ fusion->window_ = MakeUnique<Window>(window);
+ fusion->convolution_dimension_numbers_ =
+ MakeUnique<ConvolutionDimensionNumbers>(conv_dnums);
+ return fusion;
+}
+
+HloInstruction* HloInstruction::FuseInstruction(
+ HloInstruction* instruction_to_fuse) {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+
+ // This fusion instruction must be a user of instruction_to_fuse.
+ CHECK_NE(0, instruction_to_fuse->users().count(this));
+ HloInstruction* fused_instruction = CloneAndFuseInternal(instruction_to_fuse);
+ CheckFusionInstruction();
+ return fused_instruction;
+}
+
+HloInstruction* HloInstruction::CloneAndFuseInternal(
+ HloInstruction* instruction_to_fuse) {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+ CHECK(instruction_to_fuse->IsFusable());
+
+ bool new_fusion_instruction = fused_instructions_.empty();
+ fused_instructions_.emplace_back(instruction_to_fuse->Clone());
+ HloInstruction* clone = fused_instructions_.back().get();
+ clone->parent_fusion_instruction_ = this;
+
+ if (new_fusion_instruction) {
+ fused_root_ = clone;
+ } else {
+ // instruction_to_fuse is necessarily an operand of the fusion instruction.
+ // After fusion this will no longer be the case. Remove the operand from the
+ // operand list and remove its corresponding fused parameter
+ // instruction. Renumber parameters as necessary to make parameter numbers
+ // consistent with their index in the fused_parameter_ vector.
+ CHECK(std::find(operands_.begin(), operands_.end(), instruction_to_fuse) !=
+ operands_.end());
+ for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
+ if (instruction_to_fuse == operands_[operand_num]) {
+        // Replace the fused parameter instruction's uses with the clone.
+ HloInstruction* fused_parameter = fused_parameters_[operand_num];
+ fused_parameter->ReplaceAllUsesWith(clone);
+
+ // Remove the corresponding fused parameter and operand from their
+ // respective vectors.
+ fused_parameters_.erase(fused_parameters_.begin() + operand_num);
+ operands_.erase(operands_.begin() + operand_num);
+
+ // Renumber fused parameter numbers to match the vector index.
+ while (operand_num < fused_parameters_.size()) {
+ fused_parameters_[operand_num]->parameter_number_ = operand_num;
+ operand_num++;
+ }
+ // Throw removed fused parameter instruction away.
+ auto inst_it =
+ std::find_if(fused_instructions_.begin(), fused_instructions_.end(),
+ [=](const std::unique_ptr<HloInstruction>& inst) {
+ return inst.get() == fused_parameter;
+ });
+ CHECK(inst_it != fused_instructions_.end());
+ fused_instructions_.erase(inst_it);
+ break;
+ }
+ }
+ // We've cloned instruction_to_fuse into this fusion instruction, so this
+ // fusion instruction is no longer a use of instruction_to_fuse.
+ instruction_to_fuse->RemoveUser(this);
+ }
+
+ // Add each operand of the clone as an operand of the fusion instruction. A
+ // complication is that some clone operands may already be operands of the
+ // fusion instruction.
+ for (int64 operand_num = 0; operand_num < clone->operand_count();
+ ++operand_num) {
+ HloInstruction* operand = clone->mutable_operand(operand_num);
+
+ // See if this operand is already an operand of the fusion node.
+ CHECK_EQ(operands_.size(), fused_parameters_.size());
+ HloInstruction* fused_param = nullptr;
+ for (int64 i = 0; i < operands_.size(); ++i) {
+ if (operands_[i] == operand) {
+ fused_param = fused_parameters_[i];
+ break;
+ }
+ }
+
+ if (fused_param == nullptr) {
+ // Clone's operand was not already an operand of the fusion
+ // instruction. Add it as an operand and add a corresponding fused
+ // parameter instruction.
+ int64 param_no = fused_parameters_.size();
+ std::unique_ptr<HloInstruction> param_instruction =
+ CreateParameter(param_no, operand->shape(), "fusion_param");
+
+ param_instruction->parent_fusion_instruction_ = this;
+ fused_parameters_.push_back(param_instruction.get());
+ fused_instructions_.push_back(std::move(param_instruction));
+ AppendOperand(operand);
+
+ fused_param = fused_instructions_.back().get();
+ }
+ clone->ReplaceOperandWith(operand_num, fused_param);
+ }
+
+ return clone;
+}
+
+RandomDistribution HloInstruction::random_distribution() const {
+ CHECK_EQ(opcode_, HloOpcode::kRng);
+ return distribution_;
+}
+
+namespace {
+
+// Adds any HloComputations this instruction calls directly to the given set.
+void CalledComputationsInternal(
+ const HloInstruction& instruction,
+ std::set<HloComputation*>* called_computations) {
+ switch (instruction.opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ called_computations->insert(instruction.to_apply());
+ break;
+ case HloOpcode::kSelectAndScatter:
+ called_computations->insert(instruction.select());
+ called_computations->insert(instruction.scatter());
+ break;
+ case HloOpcode::kWhile:
+ called_computations->insert(instruction.while_condition());
+ called_computations->insert(instruction.while_body());
+ break;
+ case HloOpcode::kFusion:
+ for (const auto& fused_instruction : instruction.fused_instructions()) {
+ CalledComputationsInternal(*fused_instruction, called_computations);
+ }
+ break;
+ default:
+ break;
+ }
+}
+
+} // namespace
+
+std::set<HloComputation*> HloInstruction::MakeCalledComputationsSet() const {
+ std::set<HloComputation*> called_computations;
+ CalledComputationsInternal(*this, &called_computations);
+ return called_computations;
+}
+
+void HloInstruction::CheckFusionInstruction() const {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+
+ // All instructions owned by this fusion instruction must be fused, and the
+ // parent fusion instruction of the fused instructions must be 'this'.
+ for (auto& instruction : fused_instructions_) {
+ CHECK(instruction->IsFused());
+ CHECK_EQ(this, instruction->fusion_instruction());
+ }
+
+ // Fused root instruction and fused parameters must all be owned by the fusion
+ // instruction.
+ bool root_owned = false;
+ std::vector<bool> parameter_owned(fused_parameters_.size(), false);
+ for (auto& instruction : fused_instructions_) {
+ if (fused_root_ == instruction.get()) {
+ CHECK(!root_owned);
+ root_owned = true;
+ }
+ for (int i = 0; i < fused_parameters_.size(); ++i) {
+ if (fused_parameters_[i] == instruction.get()) {
+ CHECK(!parameter_owned[i]);
+ parameter_owned[i] = true;
+ }
+ }
+ }
+ CHECK(root_owned);
+ // Make sure all the parameter_owned entries are set
+ for (int i = 0; i < parameter_owned.size(); i++) {
+ CHECK(parameter_owned[i]);
+ }
+
+ // Fused root must have no users.
+ CHECK_EQ(0, fused_root_->user_count());
+
+ // All uses of fused instructions must be in the fusion instruction, and every
+ // non-root instruction must have at least one use.
+ for (auto& instruction : fused_instructions_) {
+ if (instruction.get() != fused_root_) {
+ CHECK_GT(instruction->user_count(), 0);
+ for (auto& user : instruction->users()) {
+ CHECK(user->IsFused());
+ CHECK_EQ(this, user->fusion_instruction());
+ }
+ }
+ }
+
+ // Fused parameter instructions must be numbered contiguously and match up
+ // (shapes compatible) with their respective operand.
+ CHECK_EQ(operands_.size(), fused_parameters_.size());
+ std::vector<bool> parameter_numbers(fused_parameters_.size(), false);
+ for (auto fused_param : fused_parameters_) {
+ int64 param_no = fused_param->parameter_number();
+ CHECK_GE(param_no, 0);
+ CHECK_LT(param_no, fused_parameters_.size());
+ CHECK(!parameter_numbers[param_no]);
+ parameter_numbers[param_no] = true;
+ CHECK(ShapeUtil::Compatible(fused_param->shape(),
+ operands_[param_no]->shape()));
+ }
+ // Make sure all the parameter_numbers entries were seen
+ for (int i = 0; i < parameter_numbers.size(); i++) {
+ CHECK(parameter_numbers[i]);
+ }
+
+ // Operands must be distinct.
+ std::set<HloInstruction*> operand_set(operands_.begin(), operands_.end());
+ CHECK_EQ(operand_set.size(), operands_.size());
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation) {
+ std::unique_ptr<HloInstruction> instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->to_apply_ = computation;
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target) {
+ std::unique_ptr<HloInstruction> instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kCustomCall, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->custom_call_target_ = custom_call_target.ToString();
+ return instruction;
+}
+
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
+ tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
+ std::vector<Shape> element_shapes;
+ for (auto element : elements) {
+ element_shapes.push_back(element->shape());
+ }
+ Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
+ return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
+}
+
+std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ // Explicitly call the factory for the instruction type. This is more robust
+ // in the face of code changes than copying fields explicitly. This also
+ // properly sets the user fields of the operands.
+ switch (opcode_) {
+ // Unary ops.
+ case HloOpcode::kAbs:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kCeil:
+ case HloOpcode::kCopy:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kNegate:
+ case HloOpcode::kSign:
+ case HloOpcode::kSort:
+ case HloOpcode::kTanh:
+ CHECK_EQ(operands.size(), 1);
+ return CreateUnary(shape, opcode_, operands[0]);
+ // Binary ops.
+ case HloOpcode::kAdd:
+ case HloOpcode::kDivide:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kNe:
+ case HloOpcode::kDot:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalOr:
+ CHECK_EQ(operands.size(), 2);
+ return CreateBinary(shape, opcode_, operands[0], operands[1]);
+ // Ternary ops.
+ case HloOpcode::kClamp:
+ case HloOpcode::kSelect:
+ CHECK_EQ(operands.size(), 3);
+ return CreateTernary(shape, opcode_, operands[0], operands[1],
+ operands[2]);
+ // Other supported ops.
+ case HloOpcode::kBroadcast:
+ CHECK_EQ(operands.size(), 1);
+ return CreateBroadcast(shape, operands[0], dimensions_);
+ case HloOpcode::kCall:
+ return CreateCall(shape, operands, to_apply_);
+ case HloOpcode::kCustomCall:
+ return CreateCustomCall(shape, operands, custom_call_target_);
+ case HloOpcode::kConcatenate:
+ return CreateConcatenate(shape, operands, dimensions(0));
+ case HloOpcode::kConvert:
+ CHECK_EQ(operands.size(), 1);
+ return CreateConvert(shape, operands[0]);
+ case HloOpcode::kConvolution:
+ CHECK_EQ(operands.size(), 2);
+ return CreateConvolve(shape, operands[0], operands[1], *window_,
+ *convolution_dimension_numbers_);
+ case HloOpcode::kCrossReplicaSum:
+ CHECK_EQ(operands.size(), 1);
+ return CreateCrossReplicaSum(shape, operands[0]);
+ case HloOpcode::kGetTupleElement:
+ CHECK_EQ(operands.size(), 1);
+ return CreateGetTupleElement(shape, operands[0], tuple_index());
+ case HloOpcode::kMap:
+ return CreateMap(shape, operands, to_apply_);
+ case HloOpcode::kPad:
+ CHECK_EQ(operands.size(), 2);
+ return CreatePad(shape, operands[0], operands[1], *padding_config_);
+ case HloOpcode::kReduce:
+ CHECK_EQ(operands.size(), 2);
+ return CreateReduce(shape, operands[0], operands[1], dimensions_,
+ to_apply_);
+ case HloOpcode::kReduceWindow:
+ CHECK_EQ(operands.size(), 2);
+ return CreateReduceWindow(shape, operands[0], operands[1], *window_,
+ to_apply_);
+ case HloOpcode::kSelectAndScatter:
+ CHECK_EQ(operands.size(), 3);
+ return CreateSelectAndScatter(shape, operands[0], select_, *window_,
+ operands[1], operands[2], scatter_);
+ case HloOpcode::kRecv:
+ CHECK_EQ(operands.size(), 0);
+ return CreateRecv(shape);
+ case HloOpcode::kReverse:
+ CHECK_EQ(operands.size(), 1);
+ return CreateReverse(shape, operands[0], dimensions_);
+ case HloOpcode::kRng:
+ return CreateRng(shape, distribution_, operands);
+ case HloOpcode::kReshape:
+ CHECK_EQ(operands.size(), 1);
+ return CreateReshape(shape, operands[0]);
+ case HloOpcode::kSend:
+ CHECK_EQ(operands.size(), 1);
+ return CreateSend(operands[0]);
+ case HloOpcode::kSlice:
+ CHECK_EQ(operands.size(), 1);
+ return CreateSlice(shape, operands[0], slice_starts_, slice_limits_);
+ case HloOpcode::kDynamicSlice:
+ return CreateDynamicSlice(shape, operands[0], operands[1],
+ dynamic_slice_sizes_);
+ case HloOpcode::kDynamicUpdateSlice:
+ CHECK_EQ(operands.size(), 3);
+ return CreateDynamicUpdateSlice(shape, operands[0], operands[1],
+ operands[2]);
+ case HloOpcode::kTranspose:
+ CHECK_EQ(operands.size(), 1);
+ return CreateTranspose(shape, operands[0], dimensions_);
+ case HloOpcode::kTuple:
+ return CreateTuple(operands);
+ case HloOpcode::kWhile:
+ CHECK_EQ(operands.size(), 1);
+ return CreateWhile(shape, condition_, body_, operands[0]);
+ case HloOpcode::kConstant:
+ return CreateConstant(LiteralUtil::CloneToUnique(*literal_));
+ case HloOpcode::kFusion:
+ return CloneFusionWithNewOperands(shape, operands);
+ case HloOpcode::kParameter:
+ return CreateParameter(parameter_number_, shape, parameter_name_);
+ // Unsupported ops for cloning.
+ case HloOpcode::kUpdate:
+ case HloOpcode::kIndex:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kTrace:
+ LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
+ }
+}
+
+std::unique_ptr<HloInstruction> HloInstruction::Clone() {
+ std::unique_ptr<HloInstruction> clone =
+ CloneWithNewOperands(shape_, operands_);
+ clone->name_ = name() + ".clone";
+ return clone;
+}
+
+std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+
+ auto new_instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
+ // Add the operands to our new fusion instruction.
+ for (HloInstruction* new_operand : operands) {
+ new_instruction->AppendOperand(new_operand);
+ }
+ // Clone all the fused instructions for the new fusion instruction.
+ std::map<HloInstruction*, HloInstruction*> old_to_new;
+ std::list<std::unique_ptr<HloInstruction>> new_fused_instructions;
+ // Create the list of fused parameters by mapping through the cloned,
+ // fused instructions.
+ std::vector<HloInstruction*> new_fused_parameters;
+ for (HloInstruction* old_fused_parameter : fused_parameters_) {
+ new_fused_instructions.push_back(old_fused_parameter->Clone());
+ HloInstruction* new_fusion_parameter = new_fused_instructions.back().get();
+ new_fusion_parameter->parent_fusion_instruction_ = new_instruction.get();
+ new_fused_parameters.push_back(new_fusion_parameter);
+ InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter);
+ }
+ for (auto old_fused_instruction_iter = fused_instructions_.rbegin();
+ old_fused_instruction_iter != fused_instructions_.rend();
+ ++old_fused_instruction_iter) {
+ HloInstruction* old_fused_instruction = old_fused_instruction_iter->get();
+ if (old_fused_instruction->opcode() == HloOpcode::kParameter) {
+ FindOrDie(old_to_new, old_fused_instruction);
+ continue;
+ }
+ std::vector<HloInstruction*> new_operands;
+ for (int64 operand_idx = 0;
+ operand_idx < old_fused_instruction->operand_count(); ++operand_idx) {
+ HloInstruction* old_operand =
+ old_fused_instruction->mutable_operand(operand_idx);
+ new_operands.push_back(FindOrDie(old_to_new, old_operand));
+ }
+ new_fused_instructions.push_back(
+ old_fused_instruction->CloneWithNewOperands(
+ old_fused_instruction->shape(), new_operands));
+ HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
+ new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
+ InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
+ }
+ // We iterated the fused instructions in reverse post order, which means that
+ // we must reverse our new list of fused instructions.
+ std::reverse(new_fused_instructions.begin(), new_fused_instructions.end());
+ new_instruction->fusion_kind_ = fusion_kind_;
+ new_instruction->fused_instructions_ = std::move(new_fused_instructions);
+ new_instruction->fused_parameters_ = std::move(new_fused_parameters);
+ new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_);
+ new_instruction->CheckFusionInstruction();
+ return new_instruction;
+}
+
+const Literal& HloInstruction::literal() const {
+ CHECK_EQ(HloOpcode::kConstant, opcode_);
+ return *literal_;
+}
+
+bool HloInstruction::CanHaveDimensionsField() const {
+ return (opcode() == HloOpcode::kReverse ||
+ opcode() == HloOpcode::kConcatenate ||
+ opcode() == HloOpcode::kReduce || opcode() == HloOpcode::kBroadcast ||
+ opcode() == HloOpcode::kTranspose);
+}
+
+const std::vector<int64>& HloInstruction::dimensions() const {
+ CHECK(CanHaveDimensionsField());
+ return dimensions_;
+}
+
+int64 HloInstruction::dimensions(int64 index) const {
+ return dimensions()[index];
+}
+
+int64 HloInstruction::concatenate_dimension() const {
+ CHECK(opcode() == HloOpcode::kConcatenate);
+ CHECK_EQ(1, dimensions_.size());
+ return dimensions(0);
+}
+
+int64 HloInstruction::tuple_index() const {
+ CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
+ return tuple_index_;
+}
+
+const HloInstruction* HloInstruction::operand(int64 i) const {
+ return operands_[i];
+}
+
+HloInstruction* HloInstruction::mutable_operand(int64 i) {
+ CHECK(operands_[i] != nullptr);
+ return operands_[i];
+}
+
+int64 HloInstruction::operand_index(const HloInstruction* target) const {
+ for (int64 i = 0; i < operand_count(); ++i) {
+ if (target == operand(i)) {
+ return i;
+ }
+ }
+ LOG(FATAL) << "target was not an operand";
+}
+
+void HloInstruction::AppendOperand(HloInstruction* operand) {
+ operands_.push_back(operand);
+ operand->AddUser(this);
+}
+
+void HloInstruction::AddUser(HloInstruction* user) { users_.insert(user); }
+
+bool HloInstruction::IsConstant() const {
+ return opcode_ == HloOpcode::kConstant;
+}
+
+bool HloInstruction::HasConstantOperand() const {
+ for (const HloInstruction* operand : operands_) {
+ if (operand->IsConstant()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void HloInstruction::AddControlPredecessor(HloInstruction* instruction) {
+ control_predecessors_.insert(instruction);
+}
+
+bool HloInstruction::Identical(
+ const HloInstruction& other,
+ std::function<bool(const HloInstruction*, const HloInstruction*)>
+ eq_operands,
+ std::function<bool(const HloComputation*, const HloComputation*)>
+ eq_computations) const {
+ // An instruction is always identical to itself.
+ if (this == &other) {
+ return true;
+ }
+
+ // Identical instructions must have the same opcode and identical operands. In
+ // general, there is no need to check the shape because it is inferred from
+ // the shapes of the operands.
+ if (opcode() != other.opcode() ||
+ !ContainersEqual(operands(), other.operands(), eq_operands)) {
+ return false;
+ }
+
+ // Perform opcode specific checks.
+ switch (opcode()) {
+ // The result of these instructions only depends upon their opcode and
+ // operands.
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kCopy:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDot:
+ case HloOpcode::kEq:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kLogicalOr:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSign:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTuple:
+ return true;
+
+ // These opcodes have complex or special behavior so just return false.
+ case HloOpcode::kFusion:
+ case HloOpcode::kRng:
+ case HloOpcode::kTrace:
+ case HloOpcode::kWhile:
+ return false;
+
+ case HloOpcode::kParameter:
+ return parameter_number() == other.parameter_number() &&
+ // Check the shape too because `this` and `other` may be in
+ // different HloComputations.
+ ShapeUtil::Compatible(shape(), other.shape());
+
+ // A constant is defined by the value in the literal.
+ case HloOpcode::kConstant:
+ return LiteralUtil::Equal(literal(), other.literal());
+
+ // A convert result is determined by the primitive type that the operand is
+ // converted into.
+ case HloOpcode::kConvert:
+ return shape().element_type() == other.shape().element_type();
+
+ // Convolution has a window and dimensions.
+ case HloOpcode::kConvolution:
+ return protobuf_util::ProtobufEquals(window(), other.window()) &&
+ protobuf_util::ProtobufEquals(
+ convolution_dimension_numbers(),
+ other.convolution_dimension_numbers());
+
+ // Reduction results are determined by the reduction dimension and the
+ // reduction computation.
+ case HloOpcode::kReduce:
+ return dimensions() == other.dimensions() &&
+ eq_computations(to_apply(), other.to_apply());
+ case HloOpcode::kReduceWindow:
+ return eq_computations(to_apply(), other.to_apply()) &&
+ protobuf_util::ProtobufEquals(window(), other.window());
+
+ // SelectAndScatter is determined by both select and scatter
+ // computation as well as the window configuration.
+ case HloOpcode::kSelectAndScatter:
+ return eq_computations(select(), other.select()) &&
+ eq_computations(scatter(), other.scatter()) &&
+ protobuf_util::ProtobufEquals(window(), other.window());
+
+ case HloOpcode::kReshape:
+ return ShapeUtil::Compatible(shape(), other.shape());
+
+ // Transpose result is determined by the final shape and the permutation.
+ case HloOpcode::kTranspose:
+ return ShapeUtil::Compatible(shape(), other.shape()) &&
+ dimensions() == other.dimensions();
+
+ // Remaining instructions with special values.
+ case HloOpcode::kBitcast:
+ return ShapeUtil::Equal(shape(), other.shape());
+ case HloOpcode::kBroadcast:
+ return ShapeUtil::Compatible(shape(), other.shape()) &&
+ dimensions() == other.dimensions();
+ case HloOpcode::kConcatenate:
+ return dimensions() == other.dimensions();
+ case HloOpcode::kGetTupleElement:
+ return tuple_index() == other.tuple_index();
+ case HloOpcode::kPad:
+ return protobuf_util::ProtobufEquals(padding_config(),
+ other.padding_config());
+ case HloOpcode::kSlice:
+ return slice_starts_ == other.slice_starts_ &&
+ slice_limits_ == other.slice_limits_;
+ case HloOpcode::kDynamicSlice:
+ return ShapeUtil::Compatible(shape(), other.shape()) &&
+ dynamic_slice_sizes_ == other.dynamic_slice_sizes_;
+ case HloOpcode::kDynamicUpdateSlice:
+ return ShapeUtil::Compatible(shape(), other.shape());
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ return eq_computations(to_apply(), other.to_apply());
+ case HloOpcode::kCustomCall:
+ return custom_call_target_ == other.custom_call_target_;
+ case HloOpcode::kReverse:
+ return dimensions() == other.dimensions();
+
+ // These opcodes are not yet supported.
+ case HloOpcode::kIndex:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kSort:
+ case HloOpcode::kUpdate:
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ return false;
+ }
+}
+
+bool HloInstruction::IsRank2Transpose() const {
+ return (opcode_ == HloOpcode::kTranspose) &&
+ dimensions_ == std::vector<int64>({1, 0}) &&
+ shape_.dimensions_size() == 2 &&
+ std::equal(shape_.dimensions().begin(), shape_.dimensions().end(),
+ operands_[0]->shape_.dimensions().rbegin());
+}
+
+void HloInstruction::RemoveUser(HloInstruction* user) {
+ auto user_it = users_.find(user);
+ CHECK(user_it != users_.end());
+ users_.erase(user_it);
+}
+
+void HloInstruction::ReplaceUseWith(HloInstruction* user,
+ HloInstruction* new_producer) {
+ CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
+ << "this shape: " << ShapeUtil::HumanString(shape())
+ << ", replacement shape: "
+ << ShapeUtil::HumanString(new_producer->shape());
+ auto user_it = std::find(users_.begin(), users_.end(), user);
+ CHECK(user_it != users_.end()) << "Instruction " << user
+ << " not a use of instruction " << this;
+ users_.erase(user_it);
+
+ CHECK_GT(std::count(user->operands_.begin(), user->operands_.end(), this), 0);
+ std::replace(user->operands_.begin(), user->operands_.end(), this,
+ new_producer);
+ new_producer->AddUser(user);
+}
+
+void HloInstruction::ReplaceOperandWith(int64 operand_num,
+ HloInstruction* new_operand) {
+ CHECK_GE(operand_num, 0);
+ CHECK_LT(operand_num, operand_count());
+ HloInstruction* old_operand = mutable_operand(operand_num);
+ CHECK(ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
+ << old_operand->shape().ShortDebugString() << " is not compatible with "
+ << new_operand->shape().ShortDebugString();
+ operands_[operand_num] = new_operand;
+
+ if (std::find(operands_.begin(), operands_.end(), old_operand) ==
+ operands_.end()) {
+ old_operand->RemoveUser(this);
+ }
+ new_operand->AddUser(this);
+}
+
+void HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
+ // We can't use range-based loop because the iterator is invalidated by call
+ // to ReplaceUseWith.
+ for (auto user = users_.begin(); user != users_.end();) {
+ auto this_user = user;
+ user++;
+ // It's possible that new_producer is a user of this instruction as might
+ // be the case when replacing an instruction with a kCopy of itself. In
+ // this case, don't do the replacement to avoid creating a cycle in the
+ // graph.
+ if (*this_user != new_producer) {
+ ReplaceUseWith(*this_user, new_producer);
+ }
+ }
+}
+
+void HloInstruction::DetachFromOperands() {
+ CHECK_EQ(0, user_count());
+ // An instruction may be repeated as an operand. To avoid calling RemoveUser
+ // twice on the same operand, keep a set of already detached operands.
+ std::set<HloInstruction*> detached_operands;
+ for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
+ HloInstruction* operand = operands_[operand_num];
+ if (detached_operands.count(operand) == 0) {
+ operand->RemoveUser(this);
+ detached_operands.insert(operand);
+ }
+ operands_[operand_num] = nullptr;
+ }
+}
+
+HloComputation* HloInstruction::to_apply() const {
+ switch (opcode_) {
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kReduce:
+ return to_apply_;
+ default:
+ LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ }
+}
+
+void HloInstruction::set_to_apply(HloComputation* computation) {
+ switch (opcode_) {
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kReduce:
+ to_apply_ = computation;
+ break;
+ default:
+ LOG(FATAL) << "Invalid instruction for to_apply(): " << ToString();
+ }
+}
+
+const string& HloInstruction::custom_call_target() const {
+ CHECK_EQ(opcode_, HloOpcode::kCustomCall);
+ return custom_call_target_;
+}
+
+HloComputation* HloInstruction::while_condition() const {
+ CHECK_EQ(HloOpcode::kWhile, opcode_);
+ return condition_;
+}
+
+HloComputation* HloInstruction::while_body() const {
+ CHECK_EQ(HloOpcode::kWhile, opcode_);
+ return body_;
+}
+
+void HloInstruction::set_while_condition(HloComputation* computation) {
+ CHECK_EQ(HloOpcode::kWhile, opcode_);
+ condition_ = computation;
+}
+
+void HloInstruction::set_while_body(HloComputation* computation) {
+ CHECK_EQ(HloOpcode::kWhile, opcode_);
+ body_ = computation;
+}
+
+HloComputation* HloInstruction::select() const {
+ CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
+ return select_;
+}
+
+HloComputation* HloInstruction::scatter() const {
+ CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
+ return scatter_;
+}
+
+void HloInstruction::set_select(HloComputation* computation) {
+ CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
+ select_ = computation;
+}
+
+void HloInstruction::set_scatter(HloComputation* computation) {
+ CHECK_EQ(HloOpcode::kSelectAndScatter, opcode_);
+ scatter_ = computation;
+}
+
+string HloInstruction::SignatureString() const {
+ string operands = tensorflow::str_util::Join(
+ operands_, ", ", [](string* out, HloInstruction* operand) {
+ tensorflow::strings::StrAppend(
+ out, ShapeUtil::HumanString(operand->shape()));
+ });
+ return tensorflow::strings::StrCat("(", operands, ") -> ",
+ ShapeUtil::HumanString(shape()));
+}
+
+string HloInstruction::ToString() const {
+ string operands;
+ if (opcode() == HloOpcode::kConstant) {
+ // For constants, emit the actual value in place of an empty operand list.
+ if (ShapeUtil::ElementsIn(shape()) <= 10) {
+ // LiteralUtil::ToString emits multidimensional arrays over multiple
+ // lines. Compact this into one line by stripping out white space.
+ string tmp = LiteralUtil::ToString(literal());
+ std::replace(tmp.begin(), tmp.end(), '\n', ' ');
+ std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
+ bool first = true;
+ // Concatenate elements in "v" with spaces separating them, but ignoring
+ // empty entries.
+ for (const auto& s : v) {
+ if (s.empty()) continue;
+ tensorflow::strings::StrAppend(&operands, (first ? "" : " "), s);
+ first = false;
+ }
+ } else {
+ // Don't try emitting large constants.
+ operands = "{...}";
+ }
+ } else {
+ operands = tensorflow::str_util::Join(
+ operands_, ", ", [](string* out, HloInstruction* operand) {
+ tensorflow::strings::StrAppend(
+ out, ShapeUtil::HumanStringWithLayout(operand->shape()), " ",
+ operand->name());
+ });
+ }
+ string extra;
+ if (LayoutUtil::HasLayout(shape())) {
+ if (ShapeUtil::IsTuple(shape())) {
+ // Tuple shapes are recursive, so the layout field of the top-level shape
+ // does not include all layout information. In this case, print out the
+ // entire shape with layout.
+ tensorflow::strings::StrAppend(&extra, ", layout=",
+ ShapeUtil::HumanStringWithLayout(shape()));
+ } else {
+ tensorflow::strings::StrAppend(
+ &extra, tensorflow::strings::Printf(
+ ", layout=%s",
+ LayoutUtil::HumanString(shape().layout()).c_str()));
+ }
+ }
+ if (CanHaveDimensionsField()) {
+ tensorflow::strings::StrAppend(
+ &extra, ", dimensions={",
+ tensorflow::str_util::Join(dimensions(), ", "), "}");
+ }
+ if (window_ != nullptr) {
+ tensorflow::strings::StrAppend(&extra, ", window=",
+ window_util::ToString(*window_));
+ }
+ if (padding_config_ != nullptr) {
+ tensorflow::strings::StrAppend(&extra, ", padding=",
+ padding_config_->ShortDebugString());
+ }
+ if (convolution_dimension_numbers_ != nullptr) {
+ tensorflow::strings::StrAppend(
+ &extra,
+ tensorflow::strings::Printf(
+ ", "
+ "conv_dim_nums={batch_dim=%lld,feature_dim=%lld,spatial_dims=(%s),"
+ "kernel_input_feature_dims=%lld,kernel_output_feature_dim=%lld,"
+ "kernel_spatial_dims=(%s)}",
+ convolution_dimension_numbers_->batch_dimension(),
+ convolution_dimension_numbers_->feature_dimension(),
+ tensorflow::str_util::Join(
+ convolution_dimension_numbers_->spatial_dimensions(), ",")
+ .c_str(),
+ convolution_dimension_numbers_->kernel_input_feature_dimension(),
+ convolution_dimension_numbers_->kernel_output_feature_dimension(),
+ tensorflow::str_util::Join(
+ convolution_dimension_numbers_->kernel_spatial_dimensions(),
+ ",")
+ .c_str()));
+ }
+ if (to_apply_ != nullptr) {
+ tensorflow::strings::StrAppend(&extra, ", computation=", to_apply_->name());
+ }
+ if (opcode() == HloOpcode::kWhile) {
+ tensorflow::strings::StrAppend(&extra, ", condition=",
+ while_condition()->name());
+ tensorflow::strings::StrAppend(&extra, ", body=", while_body()->name());
+ }
+ if (opcode() == HloOpcode::kGetTupleElement) {
+ tensorflow::strings::StrAppend(&extra, ", index=", tuple_index());
+ }
+ return tensorflow::strings::Printf(
+ "%s %s = %s(%s)%s", ShapeUtil::HumanString(shape()).c_str(),
+ name().c_str(), HloOpcodeString(opcode()).c_str(), operands.c_str(),
+ extra.c_str());
+}
+
+string HloInstruction::ToShortString() const {
+ return tensorflow::strings::Printf(
+ "%s = %s(%s)", name().c_str(), HloOpcodeString(opcode()).c_str(),
+ tensorflow::str_util::Join(operands_, ", ", [](string* out,
+ HloInstruction* operand) {
+ tensorflow::strings::StrAppend(out, operand->name());
+ }).c_str());
+}
+
+HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
+
+void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
+ trace_instruction_ = trace_instruction;
+}
+
+const string& HloInstruction::tracing_tag() const {
+ CHECK_EQ(HloOpcode::kTrace, opcode());
+ CHECK(literal_ != nullptr);
+ return literal_->u8s();
+}
+
+bool HloInstruction::IsFused() const {
+ return parent_fusion_instruction_ != nullptr;
+}
+
+bool HloInstruction::IsFusable() const {
+ // Instructions which are traced should not be fused.
+ if (tracing()) {
+ return false;
+ }
+
+ // Some kinds of instructions don't make sense to fuse.
+ switch (opcode_) {
+ case HloOpcode::kFusion:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kParameter:
+ case HloOpcode::kTrace:
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ return false;
+ default:
+ return true;
+ }
+}
+
+HloInstruction* HloInstruction::fusion_instruction() const {
+ CHECK(IsFused());
+ return parent_fusion_instruction_;
+}
+
+HloInstruction* HloInstruction::fused_expression_root() const {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+ return fused_root_;
+}
+
+HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+ CHECK_GE(parameter_number, 0);
+ CHECK_LT(parameter_number, fused_parameters_.size());
+ return fused_parameters_[parameter_number];
+}
+
+const std::list<std::unique_ptr<HloInstruction>>&
+HloInstruction::fused_instructions() const {
+ CHECK_EQ(opcode_, HloOpcode::kFusion);
+ return fused_instructions_;
+}
+
+HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
+ : shape_(shape), opcode_(opcode), name_("%" + HloOpcodeString(opcode)) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
+}
+
+Status HloInstruction::AcceptInternalVisit(DfsHloVisitor* visitor) {
+ switch (opcode_) {
+ case HloOpcode::kAbs:
+ return visitor->HandleAbs(this, operands_[0]);
+ case HloOpcode::kSign:
+ return visitor->HandleSign(this, operands_[0]);
+ case HloOpcode::kConstant:
+ return visitor->HandleConstant(this, *literal_);
+ case HloOpcode::kGetTupleElement:
+ return visitor->HandleGetTupleElement(this, operands_[0]);
+ case HloOpcode::kParameter:
+ return visitor->HandleParameter(this);
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kNe:
+ return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]);
+ case HloOpcode::kAdd:
+ return visitor->HandleAdd(this, operands_[0], operands_[1]);
+ case HloOpcode::kDivide:
+ return visitor->HandleDivide(this, operands_[0], operands_[1]);
+ case HloOpcode::kSubtract:
+ return visitor->HandleSubtract(this, operands_[0], operands_[1]);
+ case HloOpcode::kMaximum:
+ return visitor->HandleMaximum(this, operands_[0], operands_[1]);
+ case HloOpcode::kMinimum:
+ return visitor->HandleMinimum(this, operands_[0], operands_[1]);
+ case HloOpcode::kLogicalAnd:
+ return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]);
+ case HloOpcode::kLogicalOr:
+ return visitor->HandleLogicalOr(this, operands_[0], operands_[1]);
+ case HloOpcode::kConcatenate:
+ return visitor->HandleConcatenate(this, operands_);
+ case HloOpcode::kConvert:
+ return visitor->HandleConvert(this, operands_[0]);
+ case HloOpcode::kCopy:
+ return visitor->HandleCopy(this, operands_[0]);
+ case HloOpcode::kMultiply:
+ return visitor->HandleMultiply(this, operands_[0], operands_[1]);
+ case HloOpcode::kDot:
+ return visitor->HandleDot(this, operands_[0], operands_[1]);
+ case HloOpcode::kPower:
+ return visitor->HandlePower(this, operands_[0], operands_[1]);
+ case HloOpcode::kRemainder:
+ return visitor->HandleRemainder(this, operands_[0], operands_[1]);
+ case HloOpcode::kSelect:
+ return visitor->HandleSelect(this, operands_[0], operands_[1],
+ operands_[2]);
+ case HloOpcode::kConvolution:
+ return visitor->HandleConvolution(this, operands_[0], operands_[1],
+ window());
+ case HloOpcode::kCrossReplicaSum:
+ return visitor->HandleCrossReplicaSum(this);
+ case HloOpcode::kTuple:
+ return visitor->HandleTuple(this, operands_);
+ case HloOpcode::kMap:
+ return visitor->HandleMap(this, operands_, to_apply_, {});
+ case HloOpcode::kClamp:
+ return visitor->HandleClamp(this, operands_[0], operands_[1],
+ operands_[2]);
+ case HloOpcode::kReduce:
+ return visitor->HandleReduce(this, operands_[0], operands_[1],
+ dimensions_, to_apply_);
+ case HloOpcode::kReduceWindow:
+ return visitor->HandleReduceWindow(this, operands_[0], window(),
+ to_apply_);
+ case HloOpcode::kSelectAndScatter:
+ return visitor->HandleSelectAndScatter(this);
+ case HloOpcode::kNegate:
+ return visitor->HandleNegate(this, operands_[0]);
+ case HloOpcode::kExp:
+ return visitor->HandleExp(this, operands_[0]);
+ case HloOpcode::kFloor:
+ return visitor->HandleFloor(this, operands_[0]);
+ case HloOpcode::kCeil:
+ return visitor->HandleCeil(this, operands_[0]);
+ case HloOpcode::kLog:
+ return visitor->HandleLog(this, operands_[0]);
+ case HloOpcode::kTanh:
+ return visitor->HandleTanh(this, operands_[0]);
+ case HloOpcode::kLogicalNot:
+ return visitor->HandleLogicalNot(this, operands_[0]);
+ case HloOpcode::kBitcast:
+ return visitor->HandleBitcast(this);
+ case HloOpcode::kBroadcast:
+ return visitor->HandleBroadcast(this);
+ case HloOpcode::kPad:
+ return visitor->HandlePad(this);
+ case HloOpcode::kReshape:
+ return visitor->HandleReshape(this);
+ case HloOpcode::kTranspose:
+ return visitor->HandleTranspose(this);
+ case HloOpcode::kReverse:
+ return visitor->HandleReverse(this, operands_[0]);
+ case HloOpcode::kSlice:
+ return visitor->HandleSlice(this, operands_[0]);
+ case HloOpcode::kDynamicSlice:
+ return visitor->HandleDynamicSlice(this, operands_);
+ case HloOpcode::kDynamicUpdateSlice:
+ return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1],
+ operands_[2]);
+ case HloOpcode::kSort:
+ return visitor->HandleSort(this, operands_[0]);
+ case HloOpcode::kInfeed:
+ return visitor->HandleInfeed(this);
+ case HloOpcode::kRng:
+ return visitor->HandleRng(this, distribution_);
+ case HloOpcode::kWhile:
+ return visitor->HandleWhile(this, operands_[0], condition_, body_);
+ case HloOpcode::kFusion:
+ return visitor->HandleFusion(this);
+ case HloOpcode::kCall:
+ return visitor->HandleCall(this, operands_, to_apply_);
+ case HloOpcode::kCustomCall:
+ return visitor->HandleCustomCall(this, operands_, custom_call_target_);
+ case HloOpcode::kSend:
+ return visitor->HandleSend(this);
+ case HloOpcode::kRecv:
+ return visitor->HandleRecv(this);
+
+ // These opcodes are not handled here.
+ case HloOpcode::kIndex:
+ case HloOpcode::kTrace:
+ case HloOpcode::kUpdate:
+ break;
+ }
+ return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s",
+ HloOpcodeString(opcode_).c_str());
+}
+
+Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) {
+ // Do not visit this HLO node again if it has already been visited.
+ if (visitor->DidVisit(*this)) {
+ VLOG(3) << "Not visiting HLO " << this << " as it was already visited.";
+ return Status::OK();
+ }
+
+ // If the instruction is in the visiting state, the graph contains a cycle.
+ if (visitor->IsVisiting(*this)) {
+ return FailedPrecondition(
+ "A cycle is detected while visiting instruction %s",
+ ToString().c_str());
+ }
+ visitor->SetVisiting(*this);
+
+ for (auto operand : operands_) {
+ VLOG(3) << "Going to visit HLO " << operand << " as operand of HLO "
+ << this;
+ TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor));
+ }
+
+ for (auto control_predecessor : control_predecessors_) {
+ VLOG(3) << "Going to visit HLO " << control_predecessor
+ << " as a control predecessor of HLO " << this;
+ TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal(visitor));
+ }
+
+ TF_RETURN_IF_ERROR(visitor->Preprocess(this));
+ VLOG(3) << "Visiting HLO " << this;
+ TF_RETURN_IF_ERROR(AcceptInternalVisit(visitor));
+ visitor->SetVisited(*this);
+ return visitor->Postprocess(this);
+}
+
+Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) {
+ auto status = AcceptInternal(visitor);
+ if (!status.ok()) {
+ return status;
+ }
+
+ if (call_finish_visit) {
+ return visitor->FinishVisit(this);
+ } else {
+ return Status::OK();
+ }
+}
+
+namespace {
+
+// Returns true if the given order is a topological sort of exactly those
+// instructions rooted at 'root'.
+bool OrderIsTopologicalSort(HloInstruction* root,
+ const std::vector<const HloInstruction*>& order) {
+ // Create a map from instruction to its position in 'order'.
+ std::unordered_map<const HloInstruction*, int> order_position;
+ for (int i = 0; i < order.size(); i++) {
+ if (!order_position.insert(std::make_pair(order[i], i)).second) {
+ // Instruction order[i] is duplicated in the order.
+ return false;
+ }
+ }
+ // Verify that each operand of each instruction in the order is also in the
+ // order *and* that the operand's position is earlier (defs come before uses
+ // for all ops).
+ for (auto* instruction : order) {
+ for (auto* operand : instruction->operands()) {
+ if (order_position.count(operand) == 0 ||
+ order_position.at(operand) >= order_position.at(instruction)) {
+ return false;
+ }
+ }
+ }
+
+ // Create a vector of all instructions in a DFS search starting at
+ // root. 'order' should contain exactly these instructions.
+ std::vector<const HloInstruction*> visited;
+ TF_CHECK_OK(root->Accept([&visited](HloInstruction* instruction) {
+ visited.push_back(instruction);
+ return Status::OK();
+ }));
+
+ if (order_position.size() != visited.size()) {
+ return false;
+ }
+ for (auto* instruction : visited) {
+ if (order_position.count(instruction) == 0) {
+ return false;
+ }
+ }
+ // Given the conditions above, the last element of order should always be the
+ // root.
+ CHECK_EQ(root, order[order.size() - 1]);
+
+ return true;
+}
+
+} // namespace
+
+Status HloInstruction::Accept(FunctionVisitor::VisitorFunction visitor_func) {
+ FunctionVisitor visitor(visitor_func);
+ return this->Accept(&visitor);
+}
+
+Status HloInstruction::AcceptOrdered(
+ DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) {
+ DCHECK(OrderIsTopologicalSort(this, order));
+ for (auto* const_instruction : order) {
+ // The visitor can mark instructions as visited to skip particular
+ // instructions.
+ if (visitor->DidVisit(*const_instruction)) {
+ VLOG(3) << "Not visiting HLO " << const_instruction
+ << " as it was already visited.";
+ continue;
+ }
+
+ HloInstruction* instruction =
+ const_cast<HloInstruction*>(const_instruction);
+
+ TF_RETURN_IF_ERROR(visitor->Preprocess(instruction));
+ VLOG(3) << "Visiting HLO " << instruction;
+ TF_RETURN_IF_ERROR(instruction->AcceptInternalVisit(visitor));
+ visitor->SetVisited(*instruction);
+ TF_RETURN_IF_ERROR(visitor->Postprocess(instruction));
+ }
+
+ return visitor->FinishVisit(this);
+}
+
+const Shape& HloInstruction::shape() const {
+ TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
+ return shape_;
+}
+
+std::vector<int64> HloInstruction::OperandIndices(
+ const HloInstruction* operand) const {
+ std::vector<int64> result;
+ for (int64 i = 0; i < operand_count(); ++i) {
+ if (this->operand(i) == operand) {
+ result.push_back(i);
+ }
+ }
+ return result;
+}
+
+bool HloInstruction::IsElementwise() const {
+ switch (opcode_) {
+ // Nullary elementwise operations.
+ case HloOpcode::kConstant:
+ return true;
+
+ // Unary elementwise operations.
+ case HloOpcode::kAbs:
+ case HloOpcode::kCeil:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCopy:
+ case HloOpcode::kExp:
+ case HloOpcode::kFloor:
+ case HloOpcode::kLog:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kNegate:
+ case HloOpcode::kSign:
+ case HloOpcode::kTanh:
+ return true;
+
+ // Binary elementwise operations.
+ case HloOpcode::kAdd:
+ case HloOpcode::kDivide:
+ case HloOpcode::kEq:
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kPower:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalOr:
+ return true;
+
+ // Ternary elementwise operations.
+ case HloOpcode::kSelect:
+ return !ShapeUtil::IsTuple(shape_);
+ case HloOpcode::kClamp:
+ return true;
+
+ // Other operations.
+ case HloOpcode::kMap:
+ return true;
+ case HloOpcode::kFusion:
+ if (fusion_kind() != FusionKind::kLoop) {
+ return false;
+ }
+ for (auto& fused : fused_instructions()) {
+ if (fused->opcode() != HloOpcode::kParameter &&
+ !fused->IsElementwise()) {
+ return false;
+ }
+ }
+ return true;
+
+ default:
+ return false;
+ }
+}
+
+namespace {
+bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction,
+ const HloInstruction* operand) {
+ std::vector<int64> operand_indices = instruction->OperandIndices(operand);
+ return std::all_of(
+ operand_indices.begin(), operand_indices.end(),
+ [instruction](int64 operand_index) {
+ return instruction->IsElementwiseOnOperand(operand_index);
+ });
+}
+} // namespace
+
+bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
+ // For all instructions other than kFusion, being elementwise on one of the
+ // operands is equivalent to being elementwise on all the operands.
+ if (opcode() != HloOpcode::kFusion) {
+ return IsElementwise();
+ }
+
+ CHECK_EQ(HloOpcode::kFusion, opcode());
+ if (fusion_kind() != FusionKind::kLoop) {
+ return false;
+ }
+
+ // A loop fusion is elementwise on an operand if all operations between the
+ // corresponding fused parameter and the fused root (found via BFS) are
+ // elementwise.
+ std::deque<HloInstruction*> worklist;
+ std::unordered_set<const HloInstruction*> visited;
+ worklist.push_back(fused_parameter(operand_idx));
+ visited.insert(fused_parameter(operand_idx));
+ while (!worklist.empty()) {
+ HloInstruction* operand = worklist.front();
+ worklist.pop_front();
+ for (HloInstruction* user : operand->users()) {
+ if (visited.count(user)) {
+ continue;
+ }
+ if (user->IsElementwise() ||
+ IsInstructionElementwiseOnOperand(user, operand)) {
+ worklist.push_back(user);
+ visited.insert(user);
+ } else {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
+ switch (opcode_) {
+ case HloOpcode::kBitcast:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSlice:
+ case HloOpcode::kTranspose:
+ return UseKind::kUsePermutingElements;
+ case HloOpcode::kPad:
+ case HloOpcode::kReduce:
+ // Pad reuses the padding value but not the padded array elements.
+ // Reduce reuses the init value but not the operand array elements.
+ return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
+ case HloOpcode::kFusion: {
+ tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> cache;
+ // We could instead iterate backwards through fused_instructions_ here, as it
+ // is in reverse postorder, and compute whether each fused instruction
+ // reuses the value of this parameter, which would save stack space but
+ // not allow us to finish early if we find a reuse.
+ std::function<UseKind(const HloInstruction&)> reuses_parameter_elements =
+ [i, &cache, &reuses_parameter_elements](const HloInstruction& hlo) {
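+ // 'plus' joins the use kinds of two paths through the fused expression:
+ // kNoUse is the identity, two plain uses combine to a plain use, and any
+ // combination involving kReuse (or a permuting use alongside another
+ // non-kNoUse use) yields kReuse.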
+ auto plus = [](const UseKind& a, const UseKind& b) {
+ if (a == UseKind::kNoUse) return b;
+ if (b == UseKind::kNoUse) return a;
+ if (a == UseKind::kReuse || b == UseKind::kReuse) {
+ return UseKind::kReuse;
+ }
+ if (a == UseKind::kUsePermutingElements ||
+ b == UseKind::kUsePermutingElements) {
+ return UseKind::kReuse;
+ }
+ CHECK(UseKind::kUse == a && UseKind::kUse == b);
+ return UseKind::kUse;
+ };
+
+ if (hlo.opcode_ == HloOpcode::kParameter &&
+ hlo.parameter_number_ == i) {
+ return UseKind::kUse;
+ }
+ if (cache.count(&hlo) == 0) {
+ for (int64 j = 0; j < hlo.operands_.size(); ++j) {
+ UseKind old = cache[&hlo];
+ UseKind updated = plus(
+ old, std::min(hlo.OperandElementUse(j),
+ reuses_parameter_elements(*hlo.operand(j))));
+ cache[&hlo] = updated;
+ }
+ }
+ return cache[&hlo];
+ };
+ return reuses_parameter_elements(*fused_root_);
+ }
+ default:
+ return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
+ }
+}
+
+namespace {
+
+// Prereq: `order` is a permutation of {0, 1, ..., `dims.size()-1`}
+void Strip1SizedDimensions(tensorflow::protobuf::RepeatedField<int64>* dims,
+ std::vector<int64>* order) {
+ // We can't merely call StripDegenerateDimensions here as we must also delete
+ // the dimension indices.
+ // Iterate backwards so that erasing an element does not shift the positions
+ // of dimensions that have not been visited yet.
+ for (int64 i = dims->size() - 1; i >= 0; --i) {
+ if (1 == dims->Get(i)) {
+ dims->erase(dims->begin() + i);
+ // We must find this, as order must be a permutation of operand
+ // dimensions.
+ order->erase(std::find(order->begin(), order->end(), i));
+ }
+ }
+}
+
+} // namespace
+
+std::tuple<bool, std::vector<int64>, std::vector<int64>>
+HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
+ if (HloOpcode::kReshape != opcode_) {
+ return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
+ }
+ return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
+ shape_);
+}
+
+string FusionKindString(HloInstruction::FusionKind kind) {
+ switch (kind) {
+ case HloInstruction::FusionKind::kLoop:
+ return "Loop";
+ case HloInstruction::FusionKind::kInput:
+ return "Input";
+ case HloInstruction::FusionKind::kTransposeDot:
+ return "TransposeDot";
+ case HloInstruction::FusionKind::kConvBackwardFilter:
+ return "ConvBackwardFilter";
+ case HloInstruction::FusionKind::kConvBackwardInput:
+ return "ConvBackwardInput";
+ }
+}
+
+bool HloInstruction::CouldBeBitcast() const {
+ switch (opcode_) {
+ case HloOpcode::kTranspose:
+ return true;
+ case HloOpcode::kReshape:
+ return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+ default:
+ return false;
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
new file mode 100644
index 0000000000..8e7a253578
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -0,0 +1,791 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// HLO instructions are in DAG form and represent the computations that the user
+// has built up via the XLA service interface. They are ultimately lowered
+// in a platform-aware way by traversing the HLO DAG and emitting a lowered
+// form; e.g. see DfsHloVisitor.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
+
+#include <functional>
+#include <list>
+#include <memory>
+#include <set>
+#include <string>
+#include <tuple>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class HloComputation;
+
+// HLO instructions are the IR used by the high-level compiler.
+class HloInstruction {
+ public:
+ enum class FusionKind {
+ kLoop, // Fused into a loop.
+ kInput, // Fused into a reduction kernel.
+ kTransposeDot, // Fused into a dot with transposed operands.
+ kConvBackwardFilter, // Fused into a backward filter convolution.
+ kConvBackwardInput, // Fused into a backward input convolution.
+ };
+
+ // Creates a parameter-retrieving instruction.
+ static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
+ const Shape& shape,
+ const string& name);
+
+ // Creates a literal constant instruction.
+ static std::unique_ptr<HloInstruction> CreateConstant(
+ std::unique_ptr<Literal> literal);
+
+ // Creates a get tuple element instruction.
+ static std::unique_ptr<HloInstruction> CreateGetTupleElement(
+ const Shape& shape, HloInstruction* operand, int64 index);
+
+ // Creates a trace instruction that logs the input operand in the computation.
+ static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
+ HloInstruction* operand);
+
+ // Creates a random number generation instruction that fills a shape with
+ // random numbers from a given distribution.
+ static std::unique_ptr<HloInstruction> CreateRng(
+ const Shape& shape, RandomDistribution distribution,
+ tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
+
+ // Creates an n-ary elementwise operation.
+ static std::unique_ptr<HloInstruction> CreateNary(
+ const Shape& shape, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
+ // Creates a unary instruction (one operand).
+ // Precondition: opcode must be a legitimate unary operation.
+ static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
+ HloOpcode opcode,
+ HloInstruction* operand);
+
+ // Creates a binary instruction (two operands).
+ // Precondition: opcode must be a legitimate binary operation.
+ static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
+ HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs);
+
+ // Creates a ternary instruction (three operands).
+ // Precondition: opcode must be a legitimate ternary operation.
+ static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
+ HloOpcode opcode,
+ HloInstruction* lhs,
+ HloInstruction* rhs,
+ HloInstruction* ehs);
+
+ // Creates a variadic instruction (variable number of operands).
+ // Precondition: opcode must be a legitimate variadic operation.
+ static std::unique_ptr<HloInstruction> CreateVariadic(
+ const Shape& shape, HloOpcode opcode,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
+ // Creates a map instruction, where the computation (given by the handle) is
+ // applied element-wise to every element in operands (across the operands,
+ // at a given index) with the same `static_operands`.
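+ // For example (illustrative), mapping a two-parameter add computation over
+ // two F32[4] operands applies the computation to the pair of elements at
+ // each index, producing an F32[4] result.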
+ static std::unique_ptr<HloInstruction> CreateMap(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* map_computation,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
+
+ // Creates a convolution op, where rhs is the convolutional filter
+ // and window describes how the filter is applied to lhs.
+ static std::unique_ptr<HloInstruction> CreateConvolve(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+ // Creates a cross replica sum op.
+ static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
+ const Shape& shape, HloInstruction* operand);
+
+ // Creates a conversion instruction, where operand is the data to convert and
+ // shape is the target shape for the conversion.
+ static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
+ HloInstruction* operand);
+
+ // Creates an infeed instruction, which reads data of the given shape from the
+ // Infeed interface of the device.
+ static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape);
+
+ // Creates a send instruction, which sends the operand data to a receive
+ // instruction in another computation.
+ static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand);
+
+ // Creates a receive instruction, which receives data of the given shape
+ // from a send instruction in another computation.
+ static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape);
+
+ // Creates a slice instruction, where the operand is sliced by the given
+ // start/limit indices.
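+ // For example (illustrative), slicing an F32[8] operand with start_indices
+ // {2} and limit_indices {5} yields an F32[3] result holding elements 2, 3
+ // and 4 of the operand.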
+ static std::unique_ptr<HloInstruction> CreateSlice(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> start_indices,
+ tensorflow::gtl::ArraySlice<int64> limit_indices);
+
+ // Creates a slice instruction, where the first operand is sliced by
+ // start indices specified in the second operand, and by size specified in
+ // 'slice_sizes'.
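+ // For example (illustrative), dynamic-slicing an F32[8] operand with start
+ // indices evaluating to {2} and slice_sizes {3} produces an F32[3] result
+ // holding elements 2 through 4 of the operand.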
+ static std::unique_ptr<HloInstruction> CreateDynamicSlice(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* start_indices,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ // Creates a dynamic update slice instruction, which updates a slice
+ // of 'operand' with 'update' and 'start_indices'.
+ static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
+ const Shape& shape, HloInstruction* operand, HloInstruction* update,
+ HloInstruction* start_indices);
+
+ // Creates a concatenate instruction, where the operands are concatenated on
+ // the provided dimension.
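+ // For example (illustrative), concatenating F32[2,3] and F32[2,5] operands
+ // on dimension 1 yields an F32[2,8] result.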
+ static std::unique_ptr<HloInstruction> CreateConcatenate(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ int64 dimension);
+
+ // Creates a reduce instruction, where the computation (given by the handle)
+ // is applied successively to every element in operand. That is, if f is the
+ // function to apply (which either takes 2 [accumulator, value] or 3
+ // [accumulator, index, value] arguments) and init is the initial value
+ // specified for the reduction (for example, 0 for addition), then this
+ // operation will compute:
+ // f(...f(f(init, [index0], value0), [index1], value1)..., [indexN], valueN)
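+ // For example (illustrative), reducing an F32[4] operand {1, 2, 3, 4} over
+ // dimension 0 with an add computation and init 0 produces the F32[] scalar
+ // f(f(f(f(0, 1), 2), 3), 4) = 10.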
+ static std::unique_ptr<HloInstruction> CreateReduce(
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ HloComputation* reduce_computation);
+
+ // Creates a reduce-window instruction, where the computation (given
+ // by the handle) is applied window-wise at each valid window
+ // position in the operand.
+ static std::unique_ptr<HloInstruction> CreateReduceWindow(
+ const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
+ const Window& window, HloComputation* reduce_computation);
+
+ // Creates a select-and-scatter instruction, which scatters the `source` array
+ // to the selected indices of each window.
+ static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
+ const Shape& shape, HloInstruction* operand, HloComputation* select,
+ const Window& window, HloInstruction* source, HloInstruction* init_value,
+ HloComputation* scatter);
+
+ // Creates a broadcast instruction.
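+ // For example (illustrative), broadcasting an F32[3] operand to an F32[2,3]
+ // shape with broadcast_dimensions {1} maps the operand's dimension 0 to
+ // output dimension 1 and repeats the values along output dimension 0.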
+ static std::unique_ptr<HloInstruction> CreateBroadcast(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Creates a pad instruction, where the operand is padded on the edges and
+ // between the elements with the given padding value.
+ static std::unique_ptr<HloInstruction> CreatePad(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* padding_value, const PaddingConfig& padding_config);
+
+ // Creates a reshape instruction, where the operand is flattened in row-major
+ // order and then reshaped to the given result shape.
+ static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
+ HloInstruction* operand);
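+
+ // Example sketch (illustrative only; `input` is an existing F32[2, 3]
+ // HloInstruction*): the six elements are read out in row-major order and
+ // refilled into the new shape:
+ //   auto reshape = HloInstruction::CreateReshape(
+ //       ShapeUtil::MakeShape(F32, {3, 2}), input);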
+
+ // Creates a transpose instruction which permutes the operand dimensions.
+ static std::unique_ptr<HloInstruction> CreateTranspose(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
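+
+ // Example sketch (illustrative only; `input` is an existing F32[20, 10]
+ // HloInstruction*): swapping the two dimensions yields a F32[10, 20] result:
+ //   auto transpose = HloInstruction::CreateTranspose(
+ //       ShapeUtil::MakeShape(F32, {10, 20}), input, /*dimensions=*/{1, 0});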
+
+ // Creates a while instruction, given a condition computation, a body
+ // computation, and the initial value for the input of the computations. For
+ // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
+ // corresponds to the C-like pseudocode below.
+ //   int32 i = 1; while (i < 1000) { i = i * 2; }  // 'i' is the result.
+ static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
+ HloComputation* condition,
+ HloComputation* body,
+ HloInstruction* init);
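+
+ // Example sketch (illustrative only; `condition` and `body` are assumed to be
+ // HloComputation* values taking one S32 parameter each, and `one` is an S32
+ // constant instruction):
+ //   auto loop = HloInstruction::CreateWhile(
+ //       ShapeUtil::MakeShape(S32, {}), condition, body, one);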
+
+ // Creates a fusion instruction. A fusion instruction contains one or more
+ // fused instructions forming an expression with a single root
+ // "fused_root". Additional instructions can be added to the fusion
+ // instruction with the method FuseInstruction.
+ static std::unique_ptr<HloInstruction> CreateFusion(
+ const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
+
+ // Creates a fusion instruction that represents backward convolution. This is
+ // similar to CreateFusion, but with extra arguments indicating the window and
+ // dimension mapping of the backward convolution.
+ static std::unique_ptr<HloInstruction> CreateFusionForBackwardConvolution(
+ const Shape& shape, FusionKind fusion_kind, const Window& window,
+ const ConvolutionDimensionNumbers& conv_dnums,
+ HloInstruction* fused_root);
+
+ // Creates a call instruction that applies the given computation on the given
+ // operands. "shape" is the resultant shape.
+ static std::unique_ptr<HloInstruction> CreateCall(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* computation);
+
+ // Creates a custom call instruction that applies the given custom call target
+ // to the given operands. "shape" is the resultant shape.
+ static std::unique_ptr<HloInstruction> CreateCustomCall(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece custom_call_target);
+
+ // Creates a tuple instruction with the given elements. This is a convenience
+ // wrapper around CreateVariadic.
+ static std::unique_ptr<HloInstruction> CreateTuple(
+ tensorflow::gtl::ArraySlice<HloInstruction*> elements);
+
+ // Creates a reverse instruction, which reverses the order of the elements
+ // in the specified dimensions.
+ static std::unique_ptr<HloInstruction> CreateReverse(
+ const Shape& shape, HloInstruction* operand,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Returns the opcode for this instruction.
+ HloOpcode opcode() const { return opcode_; }
+
+ // Returns the result shape of this instruction.
+ const Shape& shape() const;
+
+ // Returns the (mutable) result shape of this instruction.
+ Shape* mutable_shape() { return &shape_; }
+
+ // Returns the ith operand to this instruction.
+ const HloInstruction* operand(int64 i) const;
+
+ // Returns the ith operand to this instruction (mutable).
+ HloInstruction* mutable_operand(int64 i);
+
+ // Returns the number of operands to this instruction.
+ int64 operand_count() const { return operands_.size(); }
+
+ // Returns the vector of operands of this instruction.
+ const std::vector<HloInstruction*>& operands() const { return operands_; }
+
+ // Returns the index of 'target' in the operands sequence.
+ // Precondition: target must be an operand (or a fatal error will occur).
+ int64 operand_index(const HloInstruction* target) const;
+
+ // Returns the number of users of this instruction.
+ int64 user_count() const { return users_.size(); }
+
+ // Returns the users of this instruction.
+ const std::set<HloInstruction*>& users() const { return users_; }
+
+ // Returns the set of control predecessors of this instruction. Control
+ // predecessors are the instructions that must be scheduled before the
+ // current instruction.
+ const std::set<HloInstruction*>& control_predecessors() const {
+ return control_predecessors_;
+ }
+
+ // Adds the given instruction to the set of control predecessors.
+ void AddControlPredecessor(HloInstruction* instruction);
+
+ // Returns true if "other" performs the same computation as this instruction.
+ // Layout of the instructions' output array is not considered.
+ bool Identical(
+ const HloInstruction& other,
+ std::function<bool(const HloInstruction*, const HloInstruction*)>
+ eq_operands = std::equal_to<const HloInstruction*>(),
+ std::function<bool(const HloComputation*, const HloComputation*)>
+ eq_computations = std::equal_to<const HloComputation*>()) const;
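+
+ // Example sketch (illustrative only; `x` and `y` are existing
+ // HloInstruction*): comparing two instructions while treating all operand
+ // pairs as equal, instead of the default operand-pointer comparison:
+ //   bool same = x->Identical(
+ //       *y, /*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
+ //         return true;
+ //       });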
+
+ // Returns whether the instruction has a constant operand.
+ bool HasConstantOperand() const;
+
+ // Returns whether this instruction does a rank-2 transposition.
+ bool IsRank2Transpose() const;
+
+ // Replaces the use of this instruction in "user" with "new_producer". Note
+ // that there might be multiple uses of this instruction in "user"; all will
+ // be replaced.
+ void ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
+
+ // Replaces the specified operand with new_operand.
+ void ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
+
+ // Replaces all uses of this instruction with the new producer. If
+ // new_producer is a user of this instruction then new_producer remains a user
+ // of this instruction to avoid introducing cycles into the graph.
+ void ReplaceAllUsesWith(HloInstruction* new_producer);
+
+ // Detaches an instruction from its operands. That is, remove the instruction
+ // from each operand's user set. This should only be called prior to
+ // deallocating the instruction.
+ void DetachFromOperands();
+
+ // Performs a postorder DFS visit using this node as the root. If
+ // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
+ // complete.
+ Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true);
+
+ // Performs a postorder DFS visit using this node as the root. Calls the given
+ // visitor function at each instruction.
+ Status Accept(FunctionVisitor::VisitorFunction visitor_func);
+
+ // Visits all instructions rooted at this instruction using the given visitor
+ // in the given order. 'order' must contain exactly the set of instructions
+ // rooted at this node (i.e., those accessible from a DFS traversal from this
+ // instruction). 'order' must also be a valid topological sort of these
+ // instructions (defs appear before uses).
+ Status AcceptOrdered(DfsHloVisitor* visitor,
+ const std::vector<const HloInstruction*>& order);
+
+ // Returns the literal associated with this instruction.
+ //
+ // Note: only constant and parameter opcodes have an associated literal.
+ const Literal& literal() const;
+
+ // Returns the parameter number associated with this instruction.
+ //
+ // Note: only parameter opcodes have an associated parameter number.
+ int64 parameter_number() const {
+ CHECK_EQ(HloOpcode::kParameter, opcode_);
+ return parameter_number_;
+ }
+
+ const string& parameter_name() const {
+ CHECK_EQ(HloOpcode::kParameter, opcode_);
+ return parameter_name_;
+ }
+
+ // Returns the dimension sizes or numbers associated with this instruction.
+ //
+ // Precondition: opcode() is one of concatenate, reduce, broadcast, reshape,
+ // or reverse.
+ const std::vector<int64>& dimensions() const;
+ int64 dimensions(int64 index) const;
+
+ // Accessor for the dimension in which a concatenate HLO should occur.
+ // Precondition: opcode() == HloOpcode::kConcatenate
+ int64 concatenate_dimension() const;
+
+ // Returns the tuple index associated with this instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kGetTupleElement
+ int64 tuple_index() const;
+
+ // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
+ // The setter should only be called by HloModule or HloComputation methods.
+ //
+ // Precondition: The instruction has a valid to_apply_ field.
+ HloComputation* to_apply() const;
+ void set_to_apply(HloComputation* to_apply);
+
+ // Returns the custom_call_target for CustomCall.
+ // Precondition: opcode() == HloOpcode::kCustomCall
+ const string& custom_call_target() const;
+
+ // Gets/sets the while_condition or while_body HloComputation for While. The
+ // setters should only be called by HloModule or HloComputation methods.
+ //
+ // Precondition: The instruction is a While instruction.
+ HloComputation* while_condition() const;
+ HloComputation* while_body() const;
+ void set_while_condition(HloComputation* while_condition);
+ void set_while_body(HloComputation* while_body);
+
+ // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
+ // setters should only be called by HloModule or HloComputation methods.
+ //
+ // Precondition: opcode() == HloOpcode::kSelectAndScatter.
+ HloComputation* select() const;
+ HloComputation* scatter() const;
+ void set_select(HloComputation* select);
+ void set_scatter(HloComputation* scatter);
+
+ // Returns a string for the signature of this instruction if considered as a
+ // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
+ string SignatureString() const;
+
+ // Returns a debugging string that represents this instruction.
+ string ToString() const;
+
+ // As ToString, but returns a shorter string.
+ string ToShortString() const;
+
+ // Returns a logging instruction, if the output of this instruction is logged.
+ //
+ // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
+ HloInstruction* tracing() const;
+ void set_tracing(HloInstruction* trace_instruction);
+
+ // Returns the channel id associated with the instruction. The id is shared
+ // by each Send/Recv pair and is globally unique, identifying the channel.
+ //
+ // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
+ int64 channel_id() const { return channel_id_; }
+ void set_channel_id(int64 id) { channel_id_ = id; }
+
+ // Returns a tag to be used in tracing.
+ //
+ // Precondition: opcode() == HloOpcode::kTrace
+ const string& tracing_tag() const;
+
+ // Returns whether the instruction is a constant.
+ bool IsConstant() const;
+
+ // Returns true if this instruction is fused, i.e., contained within a fusion
+ // instruction.
+ bool IsFused() const;
+
+ // Returns true if this instruction can be legally fused into a fusion
+ // instruction.
+ bool IsFusable() const;
+
+ // Returns the fusion instruction that contains this instruction.
+ //
+ // Note: only valid if this instruction is fused into a fusion instruction.
+ HloInstruction* fusion_instruction() const;
+
+ // Returns the root instruction of the fused expression contained within this
+ // fusion instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloInstruction* fused_expression_root() const;
+
+ // Returns the vector of fused instructions inside this fusion
+ // instruction. The order is a reverse postorder of the fused expression (root
+ // is first in the order).
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ const std::list<std::unique_ptr<HloInstruction>>& fused_instructions() const;
+
+ // Returns the fused parameter instruction in this fusion instruction
+ // corresponding to the given parameter number.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloInstruction* fused_parameter(int64 parameter_number) const;
+
+ FusionKind fusion_kind() const {
+ CHECK_EQ(HloOpcode::kFusion, opcode_);
+ return fusion_kind_;
+ }
+
+ // Fuses the given instruction into this fusion instruction.
+ // instruction_to_fuse is cloned and the clone is placed in the fusion
+ // instruction; instruction_to_fuse itself is unchanged. The instruction is
+ // cloned rather than moved to cleanly handle the case where the instruction
+ // has a use outside the fusion instruction. Moving such an instruction into a
+ // fusion instruction would violate the single-result invariant of HLO
+ // instructions and significantly complicate code generation.
+ //
+ // Precondition: this->opcode() == HloOpcode::kFusion
+ HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
+
+ // Returns the start index in the given dimension for a slice node.
+ //
+ // Precondition: opcode() == HloOpcode::kSlice
+ int64 slice_starts(int64 dimension) const {
+ CHECK_EQ(HloOpcode::kSlice, opcode_);
+ return slice_starts_[dimension];
+ }
+ const std::vector<int64>& slice_starts() const { return slice_starts_; }
+
+ // Returns the (exclusive) limit index in the given dimension for a slice
+ // node.
+ //
+ // Precondition: opcode() == HloOpcode::kSlice
+ int64 slice_limits(int64 dimension) const {
+ CHECK_EQ(HloOpcode::kSlice, opcode_);
+ return slice_limits_[dimension];
+ }
+ const std::vector<int64>& slice_limits() const {
+ CHECK_EQ(HloOpcode::kSlice, opcode_);
+ return slice_limits_;
+ }
+
+ // Returns the size of the slice in the given dimension for a dynamic
+ // slice node.
+ //
+ // Precondition: opcode() == HloOpcode::kDynamicSlice
+ int64 slice_sizes(int64 dimension) const {
+ CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
+ return dynamic_slice_sizes_[dimension];
+ }
+ const std::vector<int64>& dynamic_slice_sizes() const {
+ CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
+ return dynamic_slice_sizes_;
+ }
+
+ // Returns data on the window in a windowed operation such as
+ // convolution.
+ const Window& window() const {
+ CHECK(window_ != nullptr);
+ return *window_;
+ }
+
+ // Returns the padding configuration for a pad node.
+ //
+ // Precondition: opcode() == HloOpcode::kPad
+ const PaddingConfig& padding_config() const {
+ CHECK(padding_config_ != nullptr);
+ return *padding_config_;
+ }
+
+ // Returns data on the dimension numbers used for a convolution
+ // operation.
+ const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
+ CHECK(convolution_dimension_numbers_ != nullptr);
+ return *convolution_dimension_numbers_;
+ }
+
+ // Returns the random distribution for this rng node.
+ //
+ // Precondition: opcode() == HloOpcode::kRng
+ RandomDistribution random_distribution() const;
+
+ // Clones the HLO instruction. The clone will have the same opcode, shape, and
+ // operands. After creation the clone has no uses. "this" (the instruction
+ // cloned from) is not changed.
+ std::unique_ptr<HloInstruction> Clone();
+
+ // Clones the HLO instruction as above but with new shape and operands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperands(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
+ // Computes and returns the computations this instruction calls (if any). This
+ // includes computations called by fused instructions inside of a fusion
+ // instruction.
+ std::set<HloComputation*> MakeCalledComputationsSet() const;
+
+ // Returns true if this instruction performs an elementwise operation on
+ // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
+ // after performing necessary implicit broadcast
+ // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
+ // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
+ // the element at {i_0,i_1,...,i_n}.
+ //
+ // Note on performance: when this instruction is kFusion, this method, in the
+ // worst case, scans all fused instructions. We could speed this up by
+ // caching.
+ bool IsElementwiseOnOperand(int64 operand_idx) const;
+
+ // Returns true if this instruction is elementwise on all its operands.
+ bool IsElementwise() const;
+
+ // Returns whether this instruction may reuse elements of its `i`th operand.
+ bool ReusesOperandElements(int64 i) const {
+ return OperandElementUse(i) == UseKind::kReuse;
+ }
+
+ // Returns the indices at which the given operand appears in the operand list
+ // of this instruction. Note that an instruction can use the same operand
+ // multiple times.
+ std::vector<int64> OperandIndices(const HloInstruction* operand) const;
+
+ // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
+ // this reshape merely inserts or deletes 1-sized dimensions, return the input
+ // indices of the deleted dimensions and the output indices of the inserted
+ // dimensions.
+ //
+ // Precondition: this op must be a reshape.
+ std::tuple<bool, std::vector<int64>, std::vector<int64>>
+ ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
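+
+ // Example sketch (illustrative only): for a reshape from F32[1, 4] to
+ // F32[4, 1], the 1-sized input dimension 0 is deleted and the 1-sized output
+ // dimension 1 is inserted, so the result would presumably be
+ // (true, {0}, {1}).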
+
+ // Returns a string identifier for this instruction. If no string identifier
+ // has been explicitly set, then the identifier is the serialized pointer to
+ // this instruction.
+ const string& name() const { return name_; }
+
+ // Sets the string identifier for this instruction.
+ void set_name(const string& name) { name_ = name; }
+
+ // Set/get the computation containing this instruction. set_parent should only
+ // be called by HloComputation methods that add instructions to or remove
+ // instructions from computations.
+ void set_parent(HloComputation* computation) { parent_ = computation; }
+ const HloComputation* parent() const { return parent_; }
+ HloComputation* parent() { return parent_; }
+
+ // Returns whether we could assign input and output layouts to this
+ // instruction to make it a bitcast.
+ bool CouldBeBitcast() const;
+
+ private:
+ enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
+
+ // Appends operand to the list of operands and adds this instruction as a user
+ // of the operand.
+ void AppendOperand(HloInstruction* operand);
+
+ // Adds a user for this instruction.
+ void AddUser(HloInstruction* user);
+
+ // Removes a user for this instruction.
+ void RemoveUser(HloInstruction* user);
+
+ // Internal constructor for a given opcode/shape, other fields must be filled
+ // by factory methods.
+ HloInstruction(HloOpcode opcode, const Shape& shape);
+
+ // Clones the given instruction_to_fuse and inserts the clone into this fusion
+ // instruction.
+ //
+ // Precondition: opcode() == HloOpcode::kFusion
+ HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse);
+
+ // Clones a fusion instruction with a new shape and operands.
+ std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+
+ // Inner DFS traversal function -- this function being called (rather than
+ // Accept above) allows us to distinguish the root of the traversal.
+ Status AcceptInternal(DfsHloVisitor* visitor);
+
+ // Inner DFS traversal function called when visiting this HloInstruction.
+ Status AcceptInternalVisit(DfsHloVisitor* visitor);
+
+ // CHECKs various invariants of a fusion instruction.
+ void CheckFusionInstruction() const;
+
+ // Returns true if this instruction can legally have the dimensions field
+ // set. Used for checking the precondition of the dimensions field accessors.
+ bool CanHaveDimensionsField() const;
+
+ // Returns how this instruction uses elements of its `i`th operand.
+ UseKind OperandElementUse(int64 i) const;
+
+ // Result shape of this instruction.
+ Shape shape_;
+
+ // Opcode for this instruction.
+ HloOpcode opcode_;
+
+ // Literal, only present for kConstant.
+ std::unique_ptr<Literal> literal_;
+
+ // Constant index, only present for kGetTupleElement.
+ int64 tuple_index_ = 0;
+
+ // Dimensions present for some operations that require reshaping or
+ // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
+ std::vector<int64> dimensions_;
+
+ // Describes the window in a windowed operation such as convolution.
+ std::unique_ptr<Window> window_;
+
+ // Describes the dimension numbers used for a convolution.
+ std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+
+ // Describes the [begin, end) index range for a slice.
+ std::vector<int64> slice_starts_;
+ std::vector<int64> slice_limits_;
+
+ // Describes the [start, start + size) range size for a dynamic slice
+ // ('start' is specified dynamically in the second operand of the operation).
+ std::vector<int64> dynamic_slice_sizes_;
+
+ // The padding configuration that describes the edge padding and interior
+ // padding of this pad instruction. Only set for pad instructions.
+ std::unique_ptr<PaddingConfig> padding_config_;
+
+ // The set of instructions fused into this fusion instruction. Only set for
+ // fusion instructions.
+ std::list<std::unique_ptr<HloInstruction>> fused_instructions_;
+
+ // If this instruction is fused into a fusion instruction, this field points
+ // to the fusion instruction.
+ HloInstruction* parent_fusion_instruction_ = nullptr;
+
+ // The vector of parameter instructions inside this fusion instruction. The
+ // index of the vector is the parameter_number of the parameter instruction.
+ // This vector is non-empty only for fusion instructions.
+ std::vector<HloInstruction*> fused_parameters_;
+
+ // The root of the expression fused into this fusion instruction.
+ HloInstruction* fused_root_ = nullptr;
+
+ // The type of the fusion. Used by kFusion only.
+ FusionKind fusion_kind_;
+
+ // For parameter instructions this field holds the parameter number.
+ int64 parameter_number_ = 0;
+ string parameter_name_;
+
+ // Computation to apply, only present for kCall, kMap, kReduce and
+ // kReduceWindow.
+ HloComputation* to_apply_ = nullptr;
+
+ // Name of a global symbol to call, only present for kCustomCall.
+ string custom_call_target_;
+
+ // Computation for condition and body of kWhile, only present for kWhile.
+ HloComputation* condition_ = nullptr;
+ HloComputation* body_ = nullptr;
+
+ // Computation for select and scatter, only present for
+ // kSelectAndScatter.
+ HloComputation* select_ = nullptr;
+ HloComputation* scatter_ = nullptr;
+
+ // Instruction operands.
+ std::vector<HloInstruction*> operands_;
+
+ // The users of this instruction. Users are HLOs where this instruction is an
+ // operand.
+ std::set<HloInstruction*> users_;
+
+ // The set of control predecessors of this instruction.
+ std::set<HloInstruction*> control_predecessors_;
+
+ // A trace instruction that consumes this instruction.
+ //
+ // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this as
+ // an operand.
+ HloInstruction* trace_instruction_ = nullptr;
+
+ // The distribution requested for random number generation.
+ // Only present for kRng.
+ RandomDistribution distribution_;
+
+ // Represents a unique identifier for each Send/Recv instruction pair.
+ // Only present for kSend or kRecv.
+ int64 channel_id_ = -1;
+
+ // String identifier for instruction.
+ string name_;
+
+ // The computation in which this instruction is contained.
+ HloComputation* parent_ = nullptr;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
+};
+
+string FusionKindString(HloInstruction::FusionKind kind);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
new file mode 100644
index 0000000000..41164a2d58
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -0,0 +1,894 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+#include <set>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+namespace {
+
+#define EXPECT_ISET(A, E...) EXPECT_EQ(A, (std::set<HloInstruction*>{E}))
+#define EXPECT_IVEC(A, E...) EXPECT_EQ(A, (std::vector<HloInstruction*>{E}))
+
+class HloInstructionTest : public ::testing::Test {
+ protected:
+ HloInstructionTest() {}
+
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+// Simple visitor that collects the number of users and operands for certain HLO
+// nodes. It also verifies some of the DFS visiting invariants (operands visited
+// before their users, nodes not visited twice, etc.)
+class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault {
+ public:
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ return Unimplemented("not implemented %s",
+ HloOpcodeString(hlo_instruction->opcode()).c_str());
+ }
+
+ Status HandleParameter(HloInstruction* parameter) override {
+ EXPECT_EQ(0, count_.count(parameter));
+ count_[parameter] = GetCountsForNode(parameter);
+ return Status::OK();
+ }
+
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) override {
+ EXPECT_EQ(0, count_.count(constant));
+ count_[constant] = GetCountsForNode(constant);
+ return Status::OK();
+ }
+
+ Status HandleAdd(HloInstruction* add, HloInstruction* lhs,
+ HloInstruction* rhs) override {
+ EXPECT_EQ(0, count_.count(add));
+ EXPECT_GT(count_.count(lhs), 0);
+ EXPECT_GT(count_.count(rhs), 0);
+ count_[add] = GetCountsForNode(add);
+ return Status::OK();
+ }
+
+ Status HandleNegate(HloInstruction* negate,
+ HloInstruction* operand) override {
+ EXPECT_EQ(0, count_.count(negate));
+ EXPECT_GT(count_.count(operand), 0);
+ count_[negate] = GetCountsForNode(negate);
+ return Status::OK();
+ }
+
+ Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* /*function*/,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/)
+ override {
+ EXPECT_EQ(0, count_.count(map));
+ for (HloInstruction* arg : operands) {
+ EXPECT_GT(count_.count(arg), 0);
+ }
+ count_[map] = GetCountsForNode(map);
+ return Status::OK();
+ }
+
+ Status HandleReduce(HloInstruction* reduce, HloInstruction* arg,
+ HloInstruction* init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ HloComputation* function) override {
+ EXPECT_EQ(0, count_.count(reduce));
+ EXPECT_GT(count_.count(arg), 0);
+ EXPECT_GT(count_.count(init_value), 0);
+ count_[reduce] = GetCountsForNode(reduce);
+ return Status::OK();
+ }
+
+ int64 NumOperands(const HloInstruction* node) {
+ auto count_iterator = count_.find(node);
+ EXPECT_NE(count_.end(), count_iterator);
+ return count_iterator->second.operand_count;
+ }
+
+ int64 NumUsers(const HloInstruction* node) {
+ auto count_iterator = count_.find(node);
+ EXPECT_NE(count_.end(), count_iterator);
+ return count_iterator->second.user_count;
+ }
+
+ private:
+ struct NumOpsAndUsers {
+ int64 operand_count;
+ int64 user_count;
+ };
+
+ // Helper function to count operands and users for the given HLO.
+ NumOpsAndUsers GetCountsForNode(const HloInstruction* node) {
+ NumOpsAndUsers counts{node->operand_count(), node->user_count()};
+ return counts;
+ }
+
+ // Counters for HLOs. Maps HLO to a NumOpsAndUsers.
+ std::unordered_map<const HloInstruction*, NumOpsAndUsers> count_;
+};
+
+TEST_F(HloInstructionTest, BasicProperties) {
+ auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");
+
+ EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
+ EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape()));
+ EXPECT_EQ(0, parameter->operand_count());
+}
+
+TEST_F(HloInstructionTest, UserWithTwoOperands) {
+ // [Param foo]-----> |-----|
+ // | Add |
+ // [Param bar]-----> |-----|
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+ auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(),
+ bar.get());
+ EXPECT_MATCH(add->operands(), testing::UnorderedMatcher<HloInstruction*>(
+ foo.get(), bar.get()));
+ EXPECT_ISET(foo->users(), add.get());
+ EXPECT_ISET(bar->users(), add.get());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(add->Accept(&visitor));
+
+ EXPECT_EQ(2, visitor.NumOperands(add.get()));
+ EXPECT_EQ(0, visitor.NumUsers(add.get()));
+ EXPECT_EQ(1, visitor.NumUsers(foo.get()));
+ EXPECT_EQ(1, visitor.NumUsers(bar.get()));
+}
+
+TEST_F(HloInstructionTest, MultipleUsers) {
+ // [Param foo]
+ // / | \
+ // / | \ [Param bar]
+ // / | \ |
+ // | | | |
+ // V V V V
+ // ------- ------- -----------
+ // | exp | | exp | | add |
+ // ------- ------- -----------
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+ auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get());
+ auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get());
+ auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(),
+ bar.get());
+
+ EXPECT_EQ(3, foo->user_count());
+ EXPECT_EQ(1, bar->user_count());
+ EXPECT_EQ(0, exp1->user_count());
+ EXPECT_EQ(0, exp2->user_count());
+ EXPECT_EQ(0, add->user_count());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(add->Accept(&visitor));
+
+ EXPECT_EQ(2, visitor.NumOperands(add.get()));
+ EXPECT_EQ(3, visitor.NumUsers(foo.get()));
+}
+
+TEST_F(HloInstructionTest, RepeatedUser) {
+ // Here we have an 'add' node that uses the same HLO for both of its operands.
+ // Make sure we don't count that HLO as two distinct users.
+ //
+ // [Param foo]
+ // | |
+ // | |
+ // | |
+ // V V
+ // -------
+ // | add |
+ // -------
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, foo.get(),
+ foo.get());
+ EXPECT_EQ(1, foo->user_count());
+
+ // But 'add' still has two operands, even if both are the same HLO.
+ EXPECT_EQ(2, add->operand_count());
+}
+
+TEST_F(HloInstructionTest, MultipleUsersAndOperands) {
+ // [param0] [param1]
+ // | |
+ // | [c0] |
+ // | | |
+ // V | V
+ // ------- | -------
+ // | add | <---^---> | add |
+ // ------- -------
+ // | |
+ // \ ------- /
+ // ---->| add |<----
+ // -------
+ auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0");
+ auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1");
+ auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ param0.get(), c0.get());
+ auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ c0.get(), param1.get());
+ auto addtotal = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ addleft.get(), addright.get());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(addtotal->Accept(&visitor));
+
+ EXPECT_EQ(2, visitor.NumUsers(c0.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addleft.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addright.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addtotal.get()));
+}
+
+TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) {
+ // [param0] [c0] [param1]
+ // | | |
+ // | V |
+ // | ------- |
+ // | | neg | |
+ // | ------- |
+ // V | V
+ // ------- | -------
+ // | add | <---^---> | add |
+ // ------- -------
+ // | |
+ // \ ------- /
+ // ---->| add |<----
+ // -------
+ // |
+ // V
+ // -------
+ // | neg |
+ // -------
+ auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0");
+ auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1");
+ auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto neg1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0.get());
+ auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ param0.get(), neg1.get());
+ auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ neg1.get(), param1.get());
+ auto addtotal = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ addleft.get(), addright.get());
+ auto neg2 =
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, addtotal.get());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(neg2->Accept(&visitor));
+
+ EXPECT_EQ(1, visitor.NumUsers(c0.get()));
+ EXPECT_EQ(2, visitor.NumUsers(neg1.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addleft.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addright.get()));
+ EXPECT_EQ(2, visitor.NumOperands(addtotal.get()));
+ EXPECT_EQ(1, visitor.NumOperands(neg2.get()));
+ EXPECT_EQ(0, visitor.NumUsers(neg2.get()));
+}
+
+TEST_F(HloInstructionTest, TrivialMap) {
+ // This tests creating a trivial x+1 map as the only operation.
+ //
+ // param0[100x10] ---> (map x+1)
+ //
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
+
+ // Builds an x+1.0 computation to use in a Map.
+ auto builder = HloComputation::Builder("f32+1");
+ auto param =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x"));
+ auto value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value));
+ auto add_f32 = builder.Build();
+
+ // Builds a parameter and feeds it to the map.
+ auto param0 = HloInstruction::CreateParameter(1, f32a100x10, "");
+ auto map =
+ HloInstruction::CreateMap(f32a100x10, {param0.get()}, add_f32.get());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(map->Accept(&visitor));
+
+ // Check counts. We aren't walking the mapper computation yet.
+ EXPECT_EQ(1, visitor.NumUsers(param0.get()));
+ EXPECT_EQ(0, visitor.NumUsers(map.get()));
+ EXPECT_EQ(1, visitor.NumOperands(map.get()));
+
+ // TODO(dehnert): Add walking and counters for the wrapped computation.
+}
+
+TEST_F(HloInstructionTest, TrivialReduce) {
+ // This tests creating a trivial x+y reduce as the only operation.
+ //
+ // param0[100x10] ---> (reduce x+y)
+ //
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ Shape f32v100 = ShapeUtil::MakeShape(F32, {100});
+ Shape f32a100x10 = ShapeUtil::MakeShape(F32, {100, 10});
+
+ // Builds an x+y computation to use in a Reduce.
+ auto builder = HloComputation::Builder("f32+f32");
+ auto paramx =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x"));
+ auto paramy =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, paramx, paramy));
+ auto add_f32 = builder.Build();
+
+ // Builds a parameter and an initial value and feeds them to the reduce.
+ auto param0 = HloInstruction::CreateParameter(0, f32a100x10, "");
+ auto const0 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f));
+ auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto reduce =
+ HloInstruction::CreateReduce(f32v100, param0.get(), const0.get(),
+ /*dimensions_to_reduce=*/{1}, add_f32.get());
+
+ OpAndUserCollectingVisitor visitor;
+ ASSERT_IS_OK(reduce->Accept(&visitor));
+
+ // Check counts. We aren't walking the reducer computation.
+ EXPECT_EQ(1, visitor.NumUsers(param0.get()));
+ EXPECT_EQ(1, visitor.NumUsers(const0.get()));
+ EXPECT_EQ(0, visitor.NumUsers(reduce.get()));
+ EXPECT_EQ(2, visitor.NumOperands(reduce.get()));
+}
+
+TEST_F(HloInstructionTest, ReplaceUseInBinaryOps) {
+ // Construct a graph of a few binary ops using two different
+ // parameters. Replace one of the parameters with the other parameter in one
+ // of the instructions.
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+ auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), bar.get());
+ auto add_foofoo = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), foo.get());
+ auto sum = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ add_foobar.get(), add_foofoo.get());
+
+ EXPECT_EQ(2, foo->user_count());
+ EXPECT_EQ(1, bar->user_count());
+
+ // Replace the use of foo in add_foofoo with bar.
+ foo->ReplaceUseWith(add_foofoo.get(), bar.get());
+
+ EXPECT_EQ(1, foo->user_count());
+ EXPECT_EQ(2, bar->user_count());
+
+ EXPECT_ISET(foo->users(), add_foobar.get());
+ EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get());
+
+ EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get());
+ EXPECT_IVEC(add_foobar->operands(), foo.get(), bar.get());
+ EXPECT_IVEC(add_foofoo->operands(), bar.get(), bar.get());
+}
+
+TEST_F(HloInstructionTest, ReplaceUseInVariadicOp) {
+ // Construct a tuple containing several parameters. Replace one parameter with
+ // another in the tuple.
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+ auto baz = HloInstruction::CreateParameter(2, r0f32_, "baz");
+
+ auto tuple =
+ HloInstruction::CreateTuple({foo.get(), bar.get(), baz.get(), foo.get()});
+ auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), bar.get());
+
+ EXPECT_EQ(2, foo->user_count());
+ EXPECT_ISET(foo->users(), tuple.get(), add_foobar.get());
+
+ // Replace the use of foo in tuple with bar.
+ foo->ReplaceUseWith(tuple.get(), bar.get());
+
+ EXPECT_ISET(foo->users(), add_foobar.get());
+
+ // Both uses of foo in tuple should have been replaced with bar.
+ EXPECT_IVEC(tuple->operands(), bar.get(), bar.get(), baz.get(), bar.get());
+}
+
+TEST_F(HloInstructionTest, ReplaceUseInUnaryOp) {
+ // Construct a couple unary instructions which use a parameter. Replace the
+ // use of a parameter in one of the unary ops with the other parameter.
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+
+ auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get());
+ auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get());
+
+ EXPECT_EQ(2, foo->user_count());
+ EXPECT_ISET(foo->users(), exp.get(), log.get());
+ EXPECT_EQ(0, bar->user_count());
+
+ // Replace the use of foo in exp with bar.
+ foo->ReplaceUseWith(exp.get(), bar.get());
+
+ // The use of foo in log should not have been affected.
+ EXPECT_EQ(1, foo->user_count());
+ EXPECT_ISET(foo->users(), log.get());
+ EXPECT_IVEC(log->operands(), foo.get());
+
+ // Bar should now be used in exp.
+ EXPECT_EQ(1, bar->user_count());
+ EXPECT_EQ(*bar->users().begin(), exp.get());
+ EXPECT_EQ(1, exp->operands().size());
+ EXPECT_EQ(*exp->operands().begin(), bar.get());
+}
+
+TEST_F(HloInstructionTest, ReplaceAllUsesWithInBinaryOps) {
+ // Construct a simple graph of a few binary ops using two different
+ // parameters. Replace all uses of one of the parameters with the other
+ // parameter.
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+ auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), bar.get());
+ auto add_foofoo = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), foo.get());
+ auto sum = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ add_foobar.get(), add_foofoo.get());
+
+ EXPECT_EQ(2, foo->user_count());
+ EXPECT_EQ(1, bar->user_count());
+
+ // Replace all uses of foo with bar.
+ foo->ReplaceAllUsesWith(bar.get());
+
+ EXPECT_EQ(0, foo->user_count());
+ EXPECT_EQ(2, bar->user_count());
+
+ EXPECT_ISET(bar->users(), add_foobar.get(), add_foofoo.get());
+}
+
+TEST_F(HloInstructionTest, ReplaceAllUsesInMultipleOps) {
+ // Construct a graph containing several ops (a unary, binary, and variadic)
+ // which use two parameters. Replace all uses of one of the parameters with
+ // the other parameter.
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto bar = HloInstruction::CreateParameter(1, r0f32_, "bar");
+
+ auto add_foobar = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ foo.get(), bar.get());
+ auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get());
+ auto tuple = HloInstruction::CreateTuple({foo.get(), bar.get()});
+
+ EXPECT_EQ(3, foo->user_count());
+ EXPECT_EQ(2, bar->user_count());
+
+ // Replace all uses of foo with bar.
+ foo->ReplaceAllUsesWith(bar.get());
+
+ EXPECT_EQ(0, foo->user_count());
+ EXPECT_EQ(3, bar->user_count());
+
+ EXPECT_ISET(bar->users(), add_foobar.get(), exp.get(), tuple.get());
+}
+
+// Simple visitor that collects and post-processes each node in the graph.
+class NodeCollectorAndPostProcessor : public DfsHloVisitorWithDefault {
+ public:
+ NodeCollectorAndPostProcessor() {}
+
+ Status Postprocess(HloInstruction* hlo) override {
+ post_processed_nodes_.push_back(hlo);
+ return Status::OK();
+ }
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override {
+ visited_nodes_.push_back(hlo_instruction);
+ return Status::OK();
+ }
+
+ const std::vector<const HloInstruction*>& visited_nodes() {
+ return visited_nodes_;
+ }
+
+ const std::vector<const HloInstruction*>& post_processed_nodes() {
+ return post_processed_nodes_;
+ }
+
+ private:
+ std::vector<const HloInstruction*> visited_nodes_;
+ std::vector<const HloInstruction*> post_processed_nodes_;
+};
+
+// Returns true if "vec" contains distinct nodes.
+bool Distinct(const std::vector<const HloInstruction*>& vec) {
+ std::set<const HloInstruction*> distinct_nodes(vec.begin(), vec.end());
+ return distinct_nodes.size() == vec.size();
+}
+
+TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) {
+ // Verifies all the nodes are visited and post-processed in the same order,
+ // and that each node is visited exactly once.
+ //
+ // /--> exp --\
+ // foo add
+ // \--> log --/
+ auto foo = HloInstruction::CreateParameter(0, r0f32_, "foo");
+ auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, foo.get());
+ auto log = HloInstruction::CreateUnary(r0f32_, HloOpcode::kLog, foo.get());
+ auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, exp.get(),
+ log.get());
+
+ NodeCollectorAndPostProcessor visitor;
+ ASSERT_IS_OK(add->Accept(&visitor));
+ // Verifies all the nodes are visited and post-processed in the same order.
+ EXPECT_EQ(visitor.visited_nodes(), visitor.post_processed_nodes());
+ // Verifies each node is visited exactly once.
+ EXPECT_TRUE(Distinct(visitor.visited_nodes()));
+}
+
+TEST_F(HloInstructionTest, SingletonFusionOp) {
+ // Create a fusion instruction containing a single unary operation.
+ auto constant =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto exp =
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, exp.get());
+
+ EXPECT_IVEC(fusion->operands(), constant.get());
+ EXPECT_ISET(constant->users(), fusion.get(), exp.get());
+}
+
+TEST_F(HloInstructionTest, BinaryFusionOp) {
+ // Create a fusion instruction containing a single binary operation.
+ auto constant1 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto constant2 =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.1f));
+ auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
+ constant1.get(), constant2.get());
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, add.get());
+
+ EXPECT_IVEC(fusion->operands(), constant1.get(), constant2.get());
+ EXPECT_ISET(constant1->users(), fusion.get(), add.get());
+ EXPECT_ISET(constant2->users(), fusion.get(), add.get());
+}
+
+TEST_F(HloInstructionTest, ChainFusionOp) {
+ // Create a chain of fused unary ops.
+ auto constant =
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto exp1 =
+ HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get());
+ auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get());
+ auto exp3 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp2.get());
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, exp3.get());
+ fusion->FuseInstruction(exp2.get());
+ fusion->FuseInstruction(exp1.get());
+
+ EXPECT_IVEC(fusion->operands(), constant.get());
+ EXPECT_ISET(constant->users(), fusion.get(), exp1.get());
+}
+
+TEST_F(HloInstructionTest, ComplexFusionOp) {
+ // Fuse all instructions in complicated expression:
+ //
+ // add = Add(C1, C2)
+ // clamp = Clamp(C2, add, add)
+ // exp = Exp(add)
+ // mul = Mul(exp, C3)
+ // sub = Sub(mul, clamp)
+ // tuple = Tuple({sub, sub, mul, C1})
+ //
+ // Notable complexities are repeated operands within the same instruction,
+ // different shapes, and use of a value in different expressions.
+ auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.1f));
+ auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.1f));
+ auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(9.0f));
+
+ auto add =
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get());
+ auto clamp = HloInstruction::CreateTernary(r0f32_, HloOpcode::kClamp,
+ c2.get(), add.get(), add.get());
+ auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, add.get());
+ auto mul = HloInstruction::CreateBinary(r0f32_, HloOpcode::kMultiply,
+ exp.get(), c3.get());
+ auto sub = HloInstruction::CreateBinary(r0f32_, HloOpcode::kSubtract,
+ mul.get(), clamp.get());
+ auto tuple =
+ HloInstruction::CreateTuple({sub.get(), sub.get(), mul.get(), c1.get()});
+
+ auto fusion = HloInstruction::CreateFusion(
+ r0f32_, HloInstruction::FusionKind::kLoop, tuple.get());
+ fusion->FuseInstruction(sub.get());
+ fusion->FuseInstruction(mul.get());
+ fusion->FuseInstruction(exp.get());
+ fusion->FuseInstruction(clamp.get());
+ fusion->FuseInstruction(add.get());
+
+ // Operands in the fusion instruction's operands() vector should appear in the
+ // order in which their users were fused into the fusion instruction.
+ EXPECT_IVEC(fusion->operands(), c1.get(), c3.get(), c2.get());
+ EXPECT_ISET(c1->users(), add.get(), tuple.get(), fusion.get());
+}
+
+// Convenience function for comparing two HloInstructions inside of
+// std::unique_ptrs.
+static bool Identical(std::unique_ptr<HloInstruction> instruction1,
+ std::unique_ptr<HloInstruction> instruction2) {
+ // Verify Identical is reflexive for both instructions.
+ EXPECT_TRUE(instruction1->Identical(*instruction1));
+ EXPECT_TRUE(instruction2->Identical(*instruction2));
+
+ bool is_equal = instruction1->Identical(*instruction2);
+ // Verify Identical is symmetric.
+ EXPECT_EQ(is_equal, instruction2->Identical(*instruction1));
+ return is_equal;
+}
+
+TEST_F(HloInstructionTest, IdenticalInstructions) {
+ // Test HloInstruction::Identical with a subset of instruction types.
+
+ // Create a set of random constant operands to use below. Make them matrices
+ // so dimensions are interesting.
+ auto operand1 = HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}));
+ auto operand2 = HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{10.0, 20.0}, {30.0, 40.0}}));
+ auto vector_operand = HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({42.0, 123.0}));
+ Shape shape = operand1->shape();
+
+ // Convenient short names for the operands.
+ HloInstruction* op1 = operand1.get();
+ HloInstruction* op2 = operand2.get();
+
+ // Operations which only depend on their operands and opcode.
+ EXPECT_TRUE(
+ Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1)));
+ EXPECT_FALSE(
+ Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2)));
+ EXPECT_FALSE(
+ Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1),
+ HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1)));
+
+ // Tuples.
+ EXPECT_TRUE(Identical(HloInstruction::CreateTuple({op1, op2}),
+ HloInstruction::CreateTuple({op1, op2})));
+ EXPECT_FALSE(Identical(HloInstruction::CreateTuple({op1, op2}),
+ HloInstruction::CreateTuple({op2, op1})));
+
+ // Broadcasts.
+ EXPECT_TRUE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
+ HloInstruction::CreateBroadcast(shape, op1, {0, 1})));
+ EXPECT_FALSE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}),
+ HloInstruction::CreateBroadcast(shape, op1, {1, 0})));
+ Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42});
+ Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123});
+ EXPECT_FALSE(
+ Identical(HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}),
+ HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1})));
+
+ // Binary operations.
+ EXPECT_TRUE(Identical(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2)));
+ EXPECT_FALSE(Identical(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
+ HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1)));
+ EXPECT_FALSE(Identical(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2),
+ HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2)));
+}
+
+TEST_F(HloInstructionTest, FunctionVisitor) {
+ // Verify the function visitor HloInstruction::Accept visits all instructions
+ // from a root properly given the following graph:
+ //
+ // param
+ // / \
+ // negate exp
+ // \ /
+ // add
+ const Shape f32 = ShapeUtil::MakeShape(F32, {});
+ auto param = HloInstruction::CreateParameter(0, f32, "0");
+ auto negate =
+ HloInstruction::CreateUnary(f32, HloOpcode::kNegate, param.get());
+ auto exp = HloInstruction::CreateUnary(f32, HloOpcode::kExp, param.get());
+ auto add = HloInstruction::CreateBinary(f32, HloOpcode::kAdd, negate.get(),
+ exp.get());
+
+ int visit_num = 0;
+ std::unordered_map<HloInstruction*, int> visit_order;
+ EXPECT_IS_OK(add->Accept([&visit_num, &visit_order](HloInstruction* inst) {
+ EXPECT_EQ(0, visit_order.count(inst));
+ visit_order[inst] = visit_num;
+ visit_num++;
+ return Status::OK();
+ }));
+
+ EXPECT_EQ(0, visit_order.at(param.get()));
+ // negate and exp can be visited in an arbitrary order.
+ EXPECT_TRUE(visit_order.at(exp.get()) == 1 || visit_order.at(exp.get()) == 2);
+ EXPECT_TRUE(visit_order.at(negate.get()) == 1 ||
+ visit_order.at(negate.get()) == 2);
+ EXPECT_NE(visit_order.at(exp.get()), visit_order.at(negate.get()));
+ EXPECT_EQ(3, visit_order.at(add.get()));
+}
+
+TEST_F(HloInstructionTest, FullyElementwise) {
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
+ auto x = HloInstruction::CreateParameter(0, r1f32, "x");
+ auto y = HloInstruction::CreateParameter(1, r1f32, "y");
+ auto add =
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, x.get(), y.get());
+ EXPECT_TRUE(add->IsElementwise());
+ for (int i = 0; i < add->operand_count(); ++i) {
+ EXPECT_TRUE(add->IsElementwiseOnOperand(i));
+ }
+}
+
+TEST_F(HloInstructionTest, PartiallyElementwise) {
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
+ const Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 5});
+
+ // Fused expression:
+ //
+ // p0 p1 p2 p3
+ // \ / / |
+ // mul / |
+ // \ / |
+ // div broadcast
+ // \ /
+ // max
+ //
+ // The fusion instruction is not elementwise on p3 because the broadcast is
+ // not elementwise.
+ HloComputation::Builder builder("PartiallyElementwise");
+ HloInstruction* p0 =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r2f32, "p0"));
+ HloInstruction* p1 =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r2f32, "p1"));
+ HloInstruction* p2 =
+ builder.AddInstruction(HloInstruction::CreateParameter(2, r2f32, "p2"));
+ HloInstruction* p3 =
+ builder.AddInstruction(HloInstruction::CreateParameter(3, r1f32, "p3"));
+ HloInstruction* mul = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kMultiply, p0, p1));
+ HloInstruction* div = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, mul, p2));
+ // Dimension 0 of shape [5] is mapped to dimension 1 of shape [3x5].
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, p3, {1}));
+ HloInstruction* max = builder.AddInstruction(
+ HloInstruction::CreateBinary(r2f32, HloOpcode::kMaximum, div, broadcast));
+
+ auto computation = builder.Build();
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {max, broadcast, div, mul}, HloInstruction::FusionKind::kLoop);
+ EXPECT_FALSE(fusion->IsElementwise());
+ for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
+ ++operand_idx) {
+ const HloInstruction* operand = fusion->operand(operand_idx);
+ if (operand == p3) {
+ EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
+ } else {
+ EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
+ }
+ }
+}
+
+TEST_F(HloInstructionTest, PartiallyElementwiseWithReuse) {
+ // Fused expression:
+ //
+ // x y
+ // \ / \
+ // min broadcast
+ // \ /
+ // sub
+ //
+ // The fusion instruction is elementwise on `x` because the only path from x
+ // to sub contains only elementwise operations. It is not elementwise on `y`
+ // because the path y->broadcast->sub is not all elementwise.
+ const Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ const Shape r1f32 = ShapeUtil::MakeShape(F32, {5});
+
+ HloComputation::Builder builder("PartiallyElementwiseWithReuse");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "y"));
+ HloInstruction* min = builder.AddInstruction(
+ HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, x, y));
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(r1f32, y, {0}));
+ HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
+ r1f32, HloOpcode::kSubtract, min, broadcast));
+
+ auto computation = builder.Build();
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {sub, broadcast, min}, HloInstruction::FusionKind::kLoop);
+ EXPECT_FALSE(fusion->IsElementwise());
+ for (int64 operand_idx = 0; operand_idx < fusion->operand_count();
+ ++operand_idx) {
+ if (fusion->operand(operand_idx) == x) {
+ EXPECT_TRUE(fusion->IsElementwiseOnOperand(operand_idx));
+ } else {
+ EXPECT_FALSE(fusion->IsElementwiseOnOperand(operand_idx));
+ }
+ }
+}
+
+TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
+ // Fused expression:
+ //
+ // x y
+ // | |
+ // | transpose
+ // \ /
+ // dot
+ //
+ // Tests that shapes aren't mangled by Clone().
+ const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
+ const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10});
+ const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20});
+ const Shape sout = ShapeUtil::MakeShape(F32, {5, 20});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
+ HloInstruction* reshape =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));
+
+ auto computation = builder.Build();
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, reshape}, HloInstruction::FusionKind::kTransposeDot);
+
+ auto fusion2 = fusion->Clone();
+ const HloInstruction* root = fusion->fused_expression_root();
+ const HloInstruction* root2 = fusion2->fused_expression_root();
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), root2->shape()));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(root->operand(0)->shape(), root2->operand(0)->shape()));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape()));
+ EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(),
+ root2->operand(1)->operand(0)->shape()));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
new file mode 100644
index 0000000000..35110f3946
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -0,0 +1,269 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+#include <iterator>
+#include <set>
+#include <sstream>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+HloComputation* HloModule::AddEntryComputation(
+ std::unique_ptr<HloComputation> computation) {
+ CHECK_EQ(nullptr, entry_computation_);
+ entry_computation_ = computation.get();
+ computation->set_parent(this);
+ computations_.push_back(std::move(computation));
+ return computations_.back().get();
+}
+
+HloComputation* HloModule::AddEmbeddedComputation(
+ std::unique_ptr<HloComputation> computation) {
+ computation->set_parent(this);
+ computations_.push_back(std::move(computation));
+ return computations_.back().get();
+}
+
+void HloModule::ReplaceComputations(
+ const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
+ // Replace all uses of non-canonical computations with their
+ // representatives.
+ std::vector<std::unique_ptr<HloComputation>> new_computations;
+ new_computations.reserve(computations_.size());
+
+ for (std::unique_ptr<HloComputation>& computation : computations_) {
+ for (auto& instruction : computation->instructions()) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kCall:
+ case HloOpcode::kMap:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow: {
+ HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
+ replacements, instruction->to_apply(), nullptr);
+ if (new_arg != nullptr) {
+ instruction->set_to_apply(new_arg);
+ }
+ break;
+ }
+ case HloOpcode::kWhile: {
+ HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
+ replacements, instruction->while_condition(), nullptr);
+ if (new_condition != nullptr) {
+ instruction->set_while_condition(new_condition);
+ }
+ HloComputation* new_body = tensorflow::gtl::FindWithDefault(
+ replacements, instruction->while_body(), nullptr);
+ if (new_body != nullptr) {
+ instruction->set_while_body(new_body);
+ }
+ break;
+ }
+ case HloOpcode::kSelectAndScatter: {
+ HloComputation* new_select = tensorflow::gtl::FindWithDefault(
+ replacements, instruction->select(), nullptr);
+ if (new_select != nullptr) {
+ instruction->set_select(new_select);
+ }
+ HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
+ replacements, instruction->scatter(), nullptr);
+ if (new_scatter != nullptr) {
+ instruction->set_scatter(new_scatter);
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+
+ if (replacements.find(computation.get()) == replacements.end()) {
+ new_computations.push_back(std::move(computation));
+ }
+ }
+
+ // Replace entry_computation if necessary.
+ entry_computation_ = tensorflow::gtl::FindWithDefault(
+ replacements, entry_computation_, entry_computation_);
+
+ computations_ = std::move(new_computations);
+}
+
+string HloModule::ToString() const {
+ std::ostringstream s;
+ s << "HloModule " << name() << ":\n\n";
+ s << "ENTRY " << entry_computation()->ToString() << "\n\n";
+ for (const std::unique_ptr<HloComputation>& computation : computations_) {
+ if (computation.get() != entry_computation()) {
+ s << computation->ToString() << "\n\n";
+ }
+ }
+ return s.str();
+}
+
+namespace {
+// Returns whether `hlo` is used outside the given subcomputation.
+// `instructions_in_subcomputation` is the instruction set of the given
+// subcomputation.
+bool IsUsedOutsideSubcomputation(
+ const HloInstruction& hlo,
+ const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
+ for (HloInstruction* user : hlo.users()) {
+ if (!instructions_in_subcomputation.count(user)) {
+ return true;
+ }
+ }
+ return false;
+}
+} // anonymous namespace
+
+HloInstruction* HloModule::OutlineExpressionFromComputation(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ const string& outlined_computation_name, HloComputation* computation) {
+ auto builder = HloComputation::Builder(outlined_computation_name);
+
+ // A map from original instructions to their counterparts in the new outlined
+ // function.
+ std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
+ // A set that contains all instructions to be outlined.
+ std::unordered_set<HloInstruction*> instruction_set_to_outline(
+ instructions_to_outline.begin(), instructions_to_outline.end());
+ std::vector<HloInstruction*> arguments;
+ std::vector<HloInstruction*> outputs;
+ int64 parameter_count = 0;
+ for (HloInstruction* instruction_to_outline : instructions_to_outline) {
+ // Clone the original instruction.
+ HloInstruction* outlined_instruction =
+ builder.AddInstruction(instruction_to_outline->Clone());
+
+    // Replace its operands with their counterparts in the new function.
+ for (int64 operand_num = 0;
+ operand_num < outlined_instruction->operand_count(); ++operand_num) {
+ HloInstruction* old_operand =
+ outlined_instruction->mutable_operand(operand_num);
+
+ HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
+ if (*operand_slot == nullptr) {
+ // Because instructions_to_outline is in topological order, if
+ // old_operand is not in outlined_instructions, old_operand must be an
+ // input of the outlined subcomputation and thus should be represented
+ // as a parameter in the new function.
+ arguments.push_back(old_operand);
+ *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
+ parameter_count, old_operand->shape(), ""));
+ ++parameter_count;
+ }
+ outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot);
+ }
+
+ // Insert the new instruction into the outlined_instructions map.
+ InsertOrDie(&outlined_instructions, instruction_to_outline,
+ outlined_instruction);
+
+    // Mark instruction_to_outline as an output if it is used outside the
+    // subcomputation or is the output of the original computation (i.e. used
+    // externally).
+ if (instruction_to_outline->user_count() == 0 ||
+ IsUsedOutsideSubcomputation(*instruction_to_outline,
+ instruction_set_to_outline)) {
+ outputs.push_back(instruction_to_outline);
+ }
+ }
+
+ if (outputs.size() != 1) {
+ string error_message =
+ "The subcomputation to outline has multiple outputs:\n";
+ for (HloInstruction* output : outputs) {
+ tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
+ }
+ LOG(FATAL) << error_message;
+ }
+ HloInstruction* output = outputs[0];
+
+ // Creates a call to the nested computation.
+ HloComputation* nested_computation = AddEmbeddedComputation(
+ builder.Build(FindOrDie(outlined_instructions, output)));
+ HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
+ output->shape(), arguments, nested_computation));
+
+ VLOG(2) << "Outlining the following instructions";
+ for (auto* instruction_to_outline : instructions_to_outline) {
+ VLOG(2) << " " << instruction_to_outline->ToString();
+ }
+ VLOG(2) << "as a call " << call->ToString();
+ VLOG(2) << "to " << nested_computation->ToString();
+
+ computation->ReplaceUsesOfInstruction(output, call);
+ for (auto i = instructions_to_outline.rbegin();
+ i != instructions_to_outline.rend(); ++i) {
+ computation->RemoveInstruction(*i);
+ }
+
+ return call;
+}
+
+std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
+ // First determine all root computations by building a set of nonroot
+ // computations (computations which are called by an instruction in the
+ // module).
+ std::set<HloComputation*> nonroot_computations;
+ for (auto& computation : computations_) {
+ for (auto& instruction : computation->instructions()) {
+ for (auto called_computation : instruction->MakeCalledComputationsSet()) {
+ nonroot_computations.insert(called_computation);
+ }
+ }
+ }
+
+ // Keep track of computations which have already been added to the post
+ // order. This prevents duplication as an embedded computation may be called
+ // from two different root computations.
+ std::set<HloComputation*> added_computations;
+ std::list<HloComputation*> post_order;
+ for (auto& computation : computations_) {
+ if (nonroot_computations.count(computation.get()) == 0) {
+ for (HloComputation* embedded_computation :
+ computation->MakeEmbeddedComputationsList()) {
+ if (added_computations.count(embedded_computation) == 0) {
+ post_order.push_back(embedded_computation);
+ added_computations.insert(embedded_computation);
+ }
+ }
+ // Root computations should only be encountered once.
+ CHECK_EQ(0, added_computations.count(computation.get()));
+ post_order.push_back(computation.get());
+ added_computations.insert(computation.get());
+ }
+ }
+ CHECK_EQ(post_order.size(), computations_.size());
+ return post_order;
+}
+
+uint64 HloModule::RandomNew64() const {
+ tensorflow::mutex_lock l(rng_mutex_);
+ return rng_();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h
new file mode 100644
index 0000000000..d598750da6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module.h
@@ -0,0 +1,132 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
+
+#include <list>
+#include <memory>
+#include <random>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace xla {
+
+// Describes a compilation unit at the HLO level.
+//
+// An HLO module contains one or more HLO computations. The module contains
+// one "entry" computation which produces the result. The module also
+// includes any embedded computations used by instructions such as "map" and
+// "reduce". All computations are owned by the module.
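+//
+// Example usage (a minimal sketch; the names and constant value are
+// illustrative only):
+//
+//   HloModule module("example_module");
+//   auto builder = HloComputation::Builder("entry");
+//   builder.AddInstruction(
+//       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+//   HloComputation* entry = module.AddEntryComputation(builder.Build());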
+class HloModule {
+ public:
+ explicit HloModule(const string& name,
+ const VersionedComputationHandle& entry_computation_handle)
+ : name_(name),
+ entry_computation_(nullptr),
+ has_entry_computation_handle_(true),
+ entry_computation_handle_(entry_computation_handle) {}
+
+  // Constructor without a versioned computation handle. This constructor
+  // should only be used for HloModules used outside of the XLA service
+  // (e.g., tests). The versioned handle is used by the service in the
+  // compilation cache.
+ explicit HloModule(const string& name)
+ : name_(name), entry_computation_(nullptr) {}
+
+ // Adds an entry computation to the module. A module can only have one entry
+ // computation. Returns a pointer to the newly added computation.
+ HloComputation* AddEntryComputation(
+ std::unique_ptr<HloComputation> computation);
+
+ // Adds an embedded computation to the module.
+ HloComputation* AddEmbeddedComputation(
+ std::unique_ptr<HloComputation> computation);
+
+ // Replaces all uses of computations that are keys of 'replacements' with
+ // the corresponding values in 'replacements'. Replaces the entry computation,
+ // if applicable.
+ //
+ // This function iterates over all instructions in the module to find
+ // computations to replace. We could speed it up by keeping track of users of
+ // computations.
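+  //
+  // Example (sketch; `old_comp` and `new_comp` are illustrative computation
+  // pointers owned by this module):
+  //
+  //   std::unordered_map<HloComputation*, HloComputation*> replacements;
+  //   replacements[old_comp] = new_comp;
+  //   module->ReplaceComputations(replacements);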
+ void ReplaceComputations(
+ const std::unordered_map<HloComputation*, HloComputation*>& replacements);
+
+ const string& name() const { return name_; }
+
+  // Returns a pointer to the entry computation of the module.
+ HloComputation* entry_computation() const {
+ CHECK_NE(nullptr, entry_computation_);
+ return entry_computation_;
+ }
+
+ const VersionedComputationHandle& entry_computation_handle() const {
+ return entry_computation_handle_;
+ }
+
+ const std::vector<std::unique_ptr<HloComputation>>& computations() const {
+ return computations_;
+ }
+
+ // Compute and return a post order of all computations in the module. The sort
+ // is defined like so: if computation A has an instruction which calls
+ // computation B, then A will appear after B in the sort.
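+  //
+  // For example, if the entry computation calls computations A and B, and
+  // both A and B call C, then one valid post order is {C, A, B, entry}.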
+ std::list<HloComputation*> MakeComputationPostOrder() const;
+
+ string ToString() const;
+
+ // Outlines the given expression from the given computation.
+ // instructions_to_outline contains the instructions that form the expression.
+ //
+  // Precondition: instructions in instructions_to_outline are in topological
+  // order (root of the outlined expression last). TODO(jingyue): take a set
+  // of instructions and topologically sort them.
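+  //
+  // Example (sketch; `a` and `b` are instructions in `comp`, with `b` the
+  // root of the expression being outlined):
+  //
+  //   HloInstruction* call = module->OutlineExpressionFromComputation(
+  //       {a, b}, "outlined", comp);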
+ HloInstruction* OutlineExpressionFromComputation(
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
+ const string& outlined_computation_name, HloComputation* computation);
+
+ // Returns a randomly generated uint64.
+ uint64 RandomNew64() const;
+
+ private:
+ const string name_;
+ HloComputation* entry_computation_;
+ std::vector<std::unique_ptr<HloComputation>> computations_;
+
+ // Random number generator engine to use when generating random numbers per
+ // HloModule compilation.
+  // TODO(b/25995601): Replace with better seed setting, or use /dev/random
+  // where we don't need deterministic execution.
+ mutable std::mt19937_64 rng_{42};
+ mutable tensorflow::mutex rng_mutex_;
+
+ // Versioned handle of the entry computation of the module.
+ bool has_entry_computation_handle_ = false;
+ VersionedComputationHandle entry_computation_handle_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc
new file mode 100644
index 0000000000..033b129e34
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_config.cc
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape)
+ : entry_computation_layout_(program_shape) {}
+
+string HloModuleConfig::compilation_cache_key() const {
+ string key = tensorflow::strings::StrCat("profiling=", hlo_profiling_enabled_,
+ "::", "hybrid=", has_hybrid_result_);
+ tensorflow::strings::StrAppend(&key, "::(");
+ std::vector<string> params;
+ for (const ShapeLayout& param_layout :
+ entry_computation_layout_.parameter_layouts()) {
+ params.push_back(param_layout.shape().SerializeAsString());
+ }
+ tensorflow::strings::StrAppend(
+ &key, tensorflow::str_util::Join(params, ", "), ") => ",
+ entry_computation_layout_.result_shape().SerializeAsString());
+ if (seed_ != 0) {
+ // TODO(b/32083678): force recompilation to reset global state.
+ static int counter = 0;
+ tensorflow::strings::StrAppend(&key, "forcing recompile ", counter++);
+ }
+ if (replica_count() != 1) {
+ tensorflow::strings::StrAppend(&key, "::replica_count=", replica_count());
+ }
+ return key;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h
new file mode 100644
index 0000000000..f081790869
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_config.h
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// This class gathers all settings and values which affect the compiled
+// executable outside of the HLO code itself. This includes layouts of inputs
+// and outputs to the module and settings such as HLO profiling. Together the
+// HloModule and HloModuleConfig unambiguously determine a particular
+// executable.
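+//
+// Example (sketch; `program_shape` would come from the computation being
+// compiled):
+//
+//   HloModuleConfig config(program_shape);
+//   config.enable_hlo_profiling(true);
+//   config.set_replica_count(2);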
+class HloModuleConfig {
+ public:
+ explicit HloModuleConfig(const ProgramShape& program_shape);
+
+ // Return a reference to the layout of the entry computation.
+ const ComputationLayout& entry_computation_layout() const {
+ return entry_computation_layout_;
+ }
+ ComputationLayout* mutable_entry_computation_layout() {
+ return &entry_computation_layout_;
+ }
+
+ // Sets/returns whether to enable HLO-level profiling.
+ bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; }
+ void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; }
+
+ bool has_hybrid_result() const { return has_hybrid_result_; }
+ void set_has_hybrid_result(bool has_hybrid_result) {
+ has_hybrid_result_ = has_hybrid_result;
+ }
+
+ // Sets/returns the module seed set during execution.
+ void set_seed(uint64 seed) { seed_ = seed; }
+ uint64 seed() const { return seed_; }
+
+ void set_replica_count(int64 replica_count) {
+ replica_count_ = replica_count;
+ }
+ int64 replica_count() const { return replica_count_; }
+
+ // Return a string which unambiguously represents all the fields of this data
+ // structure. Used for generating a cache key for storing the compiled
+ // executable.
+ string compilation_cache_key() const;
+
+ private:
+ ComputationLayout entry_computation_layout_;
+
+ // Whether to enable HLO-level profiling.
+ bool hlo_profiling_enabled_ = false;
+
+ // If this flag is true, the generated executable will return a ShapedBuffer
+ // holding the result of the computation. In a ShapedBuffer, tuples have their
+ // structure held in host memory and the element arrays (leaves of the tuple
+ // structure) stored in device memory. The ShapedBuffer is considered "hybrid"
+ // because its leaves are on device but its structure is stored on
+ // host. Otherwise, if this flag is false, the generated executable will
+ // return a DeviceMemoryBase where the result is held entirely in device
+ // memory.
+ bool has_hybrid_result_ = false;
+
+ // Module/graph-level seed handle.
+ uint64 seed_ = 0;
+
+ // The number of replicas to compile this binary for.
+ int64 replica_count_ = 1;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc
new file mode 100644
index 0000000000..0f4252522d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_module_test.cc
@@ -0,0 +1,101 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+
+namespace {
+
+class HloModuleTest : public HloTestBase {
+ protected:
+ HloModuleTest() {}
+
+  // Creates a computation which returns a constant.
+ std::unique_ptr<HloComputation> CreateConstantComputation() {
+ auto builder = HloComputation::Builder("Constant");
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ return builder.Build();
+ }
+
+ // Creates a computation which calls the given zero-parameter computations.
+ std::unique_ptr<HloComputation> CreateCallComputation(
+ tensorflow::gtl::ArraySlice<HloComputation*> computations) {
+ auto builder = HloComputation::Builder("Call");
+ for (auto computation : computations) {
+ builder.AddInstruction(
+ HloInstruction::CreateCall(r0f32_, {}, computation));
+ }
+ return builder.Build();
+ }
+
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+};
+
+TEST_F(HloModuleTest, OneComputationPostOrder) {
+ // Create a module with a single computation.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(CreateConstantComputation());
+
+ EXPECT_EQ(module->MakeComputationPostOrder().front(), computation);
+}
+
+TEST_F(HloModuleTest, TwoComputationsPostOrder) {
+ // Create a module with two unconnected computations.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation1 = module->AddEntryComputation(CreateConstantComputation());
+ auto computation2 =
+ module->AddEmbeddedComputation(CreateConstantComputation());
+
+ EXPECT_MATCH(
+ testing::ListToVec<HloComputation*>(module->MakeComputationPostOrder()),
+ testing::UnorderedMatcher<HloComputation*>(computation1, computation2));
+}
+
+TEST_F(HloModuleTest, DiamondComputationsPostOrder) {
+ // Create a module with a diamond call graph of computations.
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation1 =
+ module->AddEmbeddedComputation(CreateConstantComputation());
+ auto computation2 =
+ module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+ auto computation3 =
+ module->AddEmbeddedComputation(CreateCallComputation({computation1}));
+ auto computation4 = module->AddEntryComputation(
+ CreateCallComputation({computation2, computation3}));
+
+ auto post_order = module->MakeComputationPostOrder();
+ EXPECT_MATCH(testing::ListToVec<HloComputation*>(post_order),
+ testing::UnorderedMatcher<HloComputation*>(
+ computation1, computation2, computation3, computation4));
+ EXPECT_EQ(post_order.back(), computation4);
+ EXPECT_EQ(post_order.front(), computation1);
+}
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
new file mode 100644
index 0000000000..2ae3858163
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+string HloOpcodeString(HloOpcode opcode) {
+ switch (opcode) {
+ case HloOpcode::kAbs:
+ return "abs";
+ case HloOpcode::kAdd:
+ return "add";
+ case HloOpcode::kBitcast:
+ return "bitcast";
+ case HloOpcode::kBroadcast:
+ return "broadcast";
+ case HloOpcode::kCall:
+ return "call";
+ case HloOpcode::kClamp:
+ return "clamp";
+ case HloOpcode::kConcatenate:
+ return "concatenate";
+ case HloOpcode::kConstant:
+ return "constant";
+ case HloOpcode::kConvert:
+ return "convert";
+ case HloOpcode::kConvolution:
+ return "convolution";
+ case HloOpcode::kCrossReplicaSum:
+ return "cross-replica-sum";
+ case HloOpcode::kCustomCall:
+ return "custom-call";
+ case HloOpcode::kCopy:
+ return "copy";
+ case HloOpcode::kDivide:
+ return "divide";
+ case HloOpcode::kDot:
+ return "dot";
+ case HloOpcode::kDynamicSlice:
+ return "dynamic-slice";
+ case HloOpcode::kDynamicUpdateSlice:
+ return "dynamic-update-slice";
+ case HloOpcode::kEq:
+ return "equal-to";
+ case HloOpcode::kExp:
+ return "exponential";
+ case HloOpcode::kFloor:
+ return "floor";
+ case HloOpcode::kCeil:
+ return "ceil";
+ case HloOpcode::kFusion:
+ return "fusion";
+ case HloOpcode::kGe:
+ return "greater-than-or-equal-to";
+ case HloOpcode::kGetTupleElement:
+ return "get-tuple-element";
+ case HloOpcode::kGt:
+ return "greater-than";
+ case HloOpcode::kIndex:
+ return "index";
+ case HloOpcode::kInfeed:
+ return "infeed";
+ case HloOpcode::kLe:
+ return "less-than-or-equal-to";
+ case HloOpcode::kLog:
+ return "log";
+ case HloOpcode::kLogicalAnd:
+ return "logical-and";
+ case HloOpcode::kLogicalOr:
+ return "logical-or";
+ case HloOpcode::kLogicalNot:
+ return "logical-not";
+ case HloOpcode::kLt:
+ return "less-than";
+ case HloOpcode::kMap:
+ return "map";
+ case HloOpcode::kMaximum:
+ return "maximum";
+ case HloOpcode::kMinimum:
+ return "minimum";
+ case HloOpcode::kMultiply:
+ return "multiply";
+ case HloOpcode::kNe:
+ return "not-equal-to";
+ case HloOpcode::kNegate:
+ return "negate";
+ case HloOpcode::kPad:
+ return "pad";
+ case HloOpcode::kParameter:
+ return "parameter";
+ case HloOpcode::kPower:
+ return "power";
+ case HloOpcode::kRecv:
+ return "recv";
+ case HloOpcode::kReduce:
+ return "reduce";
+ case HloOpcode::kReduceWindow:
+ return "reduce-window";
+ case HloOpcode::kRemainder:
+ return "remainder";
+ case HloOpcode::kReshape:
+ return "reshape";
+ case HloOpcode::kReverse:
+ return "reverse";
+ case HloOpcode::kRng:
+ return "rng";
+ case HloOpcode::kSelectAndScatter:
+ return "select-and-scatter";
+ case HloOpcode::kSelect:
+ return "select";
+ case HloOpcode::kSend:
+ return "send";
+ case HloOpcode::kSign:
+ return "sign";
+ case HloOpcode::kSlice:
+ return "slice";
+ case HloOpcode::kSort:
+ return "sort";
+ case HloOpcode::kSubtract:
+ return "subtract";
+ case HloOpcode::kTanh:
+ return "tanh";
+ case HloOpcode::kTrace:
+ return "trace";
+ case HloOpcode::kTranspose:
+ return "transpose";
+ case HloOpcode::kTuple:
+ return "tuple";
+ case HloOpcode::kUpdate:
+ return "update";
+ case HloOpcode::kWhile:
+ return "while";
+ }
+}
+
+bool HloOpcodeIsComparison(HloOpcode opcode) {
+ switch (opcode) {
+ case HloOpcode::kGe:
+ case HloOpcode::kGt:
+ case HloOpcode::kLe:
+ case HloOpcode::kLt:
+ case HloOpcode::kEq:
+ case HloOpcode::kNe:
+ return true;
+ default:
+ return false;
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
new file mode 100644
index 0000000000..a8631b9c7d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
+
+#include <iosfwd>
+#include <string>
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// High-level optimizer instruction opcodes -- these are linear-algebra level
+// opcodes. They are a flattened form of the UnaryOp, BinaryOp, ... opcodes
+// present in the XLA service protobuf.
+//
+// See the XLA documentation for the semantics of each opcode.
+enum class HloOpcode {
+ kAbs,
+ kAdd,
+ kBitcast,
+ kBroadcast,
+ kCall,
+ kCeil,
+ kClamp,
+ kConcatenate,
+ kConstant,
+ kConvert,
+ kConvolution,
+ kCopy,
+ kCrossReplicaSum,
+ kCustomCall,
+ kDivide,
+ kDot,
+ kDynamicSlice,
+ kDynamicUpdateSlice,
+ kEq,
+ kExp,
+ kFloor,
+ kFusion,
+ kGe,
+ kGetTupleElement,
+ kGt,
+ kIndex,
+ kInfeed,
+ kLe,
+ kLog,
+ kLogicalAnd,
+ kLogicalNot,
+ kLogicalOr,
+ kLt,
+ kMap,
+ kMaximum,
+ kMinimum,
+ kMultiply,
+ kNe,
+ kNegate,
+ kPad,
+ kParameter,
+ kPower,
+ kRecv,
+ kReduce,
+ kReduceWindow,
+ kRemainder,
+ kReshape,
+ kReverse,
+ kRng,
+ kSelect,
+ kSelectAndScatter,
+ kSend,
+ kSign,
+ kSlice,
+ kSort,
+ kSubtract,
+ kTanh,
+ kTrace,
+ kTranspose,
+ kTuple,
+ kUpdate,
+ kWhile,
+};
+
+// Returns a string representation of the opcode.
+string HloOpcodeString(HloOpcode opcode);
+
+inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) {
+ return os << HloOpcodeString(opcode);
+}
+
+// Returns true iff the given opcode is a comparison operation.
+bool HloOpcodeIsComparison(HloOpcode opcode);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_OPCODE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
new file mode 100644
index 0000000000..0b64c16fdc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+// This test verifies that an example HloOpcode stringifies as expected.
+TEST(HloOpcodeTest, StringifyMultiply) {
+ ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass.h b/tensorflow/compiler/xla/service/hlo_pass.h
new file mode 100644
index 0000000000..c83d0eed00
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass.h
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Base class for HLO passes. These are used with the HloPassPipeline to
+// organize a sequence of passes.
+class HloPass {
+ public:
+ explicit HloPass(const string& name) : name_(name) {}
+ virtual ~HloPass() {}
+
+ const string& name() const { return name_; }
+
+ // Run the pass on the given HLO module. Return whether it modified the
+ // module.
+ virtual StatusOr<bool> Run(HloModule* module) = 0;
+
+ private:
+ const string name_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HloPass);
+};
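+
+// Example of a minimal pass (a sketch; `MyPass` is a hypothetical name, not a
+// pass provided by XLA):
+//
+//   class MyPass : public HloPass {
+//    public:
+//     MyPass() : HloPass("my-pass") {}
+//     StatusOr<bool> Run(HloModule* module) override {
+//       // Inspect or transform `module` here and report whether it changed.
+//       return false;
+//     }
+//   };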
+
+// Runs an HLO pass repeatedly until it makes no further changes (i.e. to a
+// fixed point).
+template <typename Pass>
+class HloPassFix : public Pass {
+ public:
+ template <typename... Args>
+ explicit HloPassFix(Args&&... args) : Pass(args...) {}
+
+ StatusOr<bool> Run(HloModule* module) override {
+ bool changed = false;
+ bool changed_this_iteration = true;
+ while (changed_this_iteration) {
+ TF_ASSIGN_OR_RETURN(changed_this_iteration, Pass::Run(module));
+ changed |= changed_this_iteration;
+ }
+ return changed;
+ }
+};
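+
+// Example (sketch; `MyPass` is a hypothetical HloPass like the one sketched
+// above, and `module` is an HloModule*):
+//
+//   HloPassFix<MyPass> fixed_point_pass;
+//   TF_ASSIGN_OR_RETURN(bool changed, fixed_point_pass.Run(module));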
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
new file mode 100644
index 0000000000..fcfd242a85
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
+
+#include <functional>
+
+#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
+ legacy_flags::HloPassPipelineFlags* flags =
+ legacy_flags::GetHloPassPipelineFlags();
+ std::vector<string> tmp =
+ tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
+ tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
+
+ string prefix = name() + ": pipeline start";
+ bool changed = false;
+ string message;
+ for (auto& pass : passes_) {
+ if (!disabled_passes.empty() && disabled_passes.count(pass->name()) > 0) {
+ continue;
+ }
+
+ // Emit label containing: "after foo-pass, before bar-pass".
+ message.clear();
+ tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name());
+ dumper_(*module, message);
+
+ VLOG(2) << "HLO " << message << ":";
+ XLA_VLOG_LINES(2, module->ToString());
+
+ TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
+
+ changed |= changed_this_pass;
+ prefix.clear();
+ tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name());
+ }
+ dumper_(*module, prefix + ", pipeline end");
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.h b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
new file mode 100644
index 0000000000..f49eed8cba
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Pipeline of HLO passes.
+class HloPassPipeline : public HloPass {
+ public:
+ explicit HloPassPipeline(const string& name,
+ const Compiler::HloDumper& dumper)
+ : HloPass(name), dumper_(dumper) {}
+
+ // Add a pass to the pipeline. It should be called with the arguments for the
+ // pass constructor:
+ //
+ // pipeline.AddPass<FooPass>(constructor_arg1, constructor_arg2);
+ //
+ // Returns a reference to the added pass.
+ template <typename T, typename... Args>
+ T& AddPass(Args&&... args) {
+ auto pass = new T(std::forward<Args>(args)...);
+ passes_.push_back(std::unique_ptr<T>(pass));
+ return *pass;
+ }
+
+ // Run all passes on the given HLO module.
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ Compiler::HloDumper dumper_;
+ std::vector<std::unique_ptr<HloPass>> passes_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
+};
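+
+// Example (sketch; `MyPass` is a hypothetical HloPass subclass, `module` is
+// an HloModule*, and the lambda matches how Run invokes the dumper, with a
+// module and a label):
+//
+//   HloPassPipeline pipeline(
+//       "example-pipeline", [](const HloModule&, const string&) {});
+//   pipeline.AddPass<MyPass>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));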
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PASS_PIPELINE_H_
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
new file mode 100644
index 0000000000..1556d1772f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+namespace hlo_query {
+
+bool IsConstantR0F32(HloInstruction* instruction, float* out) {
+ if (instruction->opcode() == HloOpcode::kConstant &&
+ ShapeUtil::IsScalarF32(instruction->shape())) {
+ *out = LiteralUtil::Get<float>(instruction->literal(), {});
+ return true;
+ }
+
+ return false;
+}
+
+bool AllOperandsAreParameters(const HloInstruction& instruction) {
+ for (const auto& operand : instruction.operands()) {
+ if (operand->opcode() != HloOpcode::kParameter) {
+ return false;
+ }
+ }
+ return true;
+}
+
+HloInstruction* GetMatchingOperand(
+ std::function<bool(const HloInstruction*)> matcher,
+ HloInstruction* instruction) {
+ for (HloInstruction* op : instruction->operands()) {
+ if (matcher(op)) {
+ return op;
+ }
+ }
+ return nullptr;
+}
+
+bool MatchBinaryInstructionOperand(
+ std::function<bool(const HloInstruction*)> matcher,
+ HloInstruction* instruction, HloInstruction** matching_operand,
+ HloInstruction** other_operand) {
+ CHECK_EQ(instruction->operand_count(), 2);
+ if (matcher(instruction->operand(0))) {
+ *matching_operand = instruction->mutable_operand(0);
+ *other_operand = instruction->mutable_operand(1);
+ return true;
+ }
+ if (matcher(instruction->operand(1))) {
+ *matching_operand = instruction->mutable_operand(1);
+ *other_operand = instruction->mutable_operand(0);
+ return true;
+ }
+ return false;
+}
+
+bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
+ HloInstruction* instruction,
+ HloInstruction** matching_operand,
+ HloInstruction** other_operand) {
+ return MatchBinaryInstructionOperand(
+ [opcode](const HloInstruction* instruction) {
+ return instruction->opcode() == opcode;
+ },
+ instruction, matching_operand, other_operand);
+}
+
+bool IsScalarConstant(const HloInstruction* instruction) {
+ return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
+}
+
+} // namespace hlo_query
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_query.h b/tensorflow/compiler/xla/service/hlo_query.h
new file mode 100644
index 0000000000..864f892e92
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_query.h
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// Helper interface for making queries about the HLO IR.
+namespace hlo_query {
+
+// Returns whether the instruction provided is a constant rank-0 float32, and
+// if so, places the constant value into out.
+// Precondition: out != nullptr
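+//
+// Example (sketch; `instruction` is an HloInstruction*):
+//   float value;
+//   if (IsConstantR0F32(instruction, &value)) { /* use value */ }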
+bool IsConstantR0F32(HloInstruction* instruction, float* out);
+
+// Returns whether all of an instruction's operands are parameters.
+bool AllOperandsAreParameters(const HloInstruction& instruction);
+
+// Returns whether the instruction is a scalar constant.
+bool IsScalarConstant(const HloInstruction* instruction);
+
+// Returns an operand of the instruction that satisfies the given matcher. If
+// there are multiple matching operands, then the first matching operand is
+// returned. If there are no matching operands then nullptr is returned.
+HloInstruction* GetMatchingOperand(
+ std::function<bool(const HloInstruction*)> matcher,
+ HloInstruction* instruction);
+
+// Returns whether a binary instruction has an operand that satisfies the
+// given matcher. If so, sets *matching_operand to that operand and
+// *other_operand to the remaining operand. Note: if both operands match, the
+// first operand of the instruction is chosen as the matching operand.
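+//
+// Example (sketch; `add` is a hypothetical binary HloInstruction*):
+//
+//   HloInstruction* scalar;
+//   HloInstruction* other;
+//   if (MatchBinaryInstructionOperand(IsScalarConstant, add, &scalar,
+//                                     &other)) {
+//     // `scalar` is the matching operand; `other` is the other operand.
+//   }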
+bool MatchBinaryInstructionOperand(
+ std::function<bool(const HloInstruction*)> matcher,
+ HloInstruction* instruction, HloInstruction** matching_operand,
+ HloInstruction** other_operand);
+
+// Returns whether a binary instruction has an operand with the given opcode.
+// This is a special case of MatchBinaryInstructionOperand.
+bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
+ HloInstruction* instruction,
+ HloInstruction** matching_operand,
+ HloInstruction** other_operand);
+
+} // namespace hlo_query
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc
new file mode 100644
index 0000000000..460dc5cf64
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+
+#include <unordered_map>
+
+namespace xla {
+
+StatusOr<bool> HloSubcomputationUnification::Run(HloModule* module) {
+ // For each computation C in the module, find the first computation C0 in the
+  // computations_ list that is identical to C, and add canon[C] = C0.
+ std::unordered_map<HloComputation*, HloComputation*> canon;
+ const auto& computations = module->computations();
+ for (auto i = computations.begin(); i != computations.end(); ++i) {
+ for (auto j = computations.begin(); j < i; ++j) {
+ // Do not waste time comparing `*i` with `*j` if `*j` is not canonical.
+ if (canon.find(j->get()) == canon.end() && **i == **j) {
+ canon[i->get()] = j->get();
+ break;
+ }
+ }
+ }
+
+ if (canon.empty()) {
+ return false;
+ }
+
+ module->ReplaceComputations(canon);
+ return true;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
new file mode 100644
index 0000000000..9ac3d5702d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// Unifies subcomputations of an `HloModule`: if any computations are equal,
+// one is chosen arbitrarily to keep and the others are deleted.
+class HloSubcomputationUnification : public HloPass {
+ public:
+ HloSubcomputationUnification() : HloPass("subcomputation unification") {}
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
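+
+// Example (sketch; `module` is an HloModule*):
+//
+//   bool changed =
+//       HloSubcomputationUnification().Run(module).ValueOrDie();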
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SUBCOMPUTATION_UNIFICATION_H_
diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
new file mode 100644
index 0000000000..14800b5342
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc
@@ -0,0 +1,205 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+
+namespace xla {
+
+class HloSubcomputationUnificationTest : public HloTestBase {
+ protected:
+ HloSubcomputationUnificationTest() {}
+
+ std::unique_ptr<HloComputation> CreateR0S32IdentityComputation() {
+ auto builder = HloComputation::Builder("Identity");
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
+ return builder.Build();
+ }
+
+ std::unique_ptr<HloComputation> CreateR0S32AdditionComputation() {
+ auto builder = HloComputation::Builder("Addition");
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32_, "x"));
+ auto y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0s32_, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
+ return builder.Build();
+ }
+
+ std::unique_ptr<HloComputation> CreateR1S32AdditionComputation(
+ const Shape& shape) {
+ auto builder = HloComputation::Builder("Addition");
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "x"));
+ auto y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, x, y));
+ return builder.Build();
+ }
+
+ Shape r0s32_ = ShapeUtil::MakeShape(S32, {});
+  Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r1s32_5_ = ShapeUtil::MakeShape(S32, {5});
+ Shape r1s32_3_ = ShapeUtil::MakeShape(S32, {3});
+};
+
+TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ auto callee1 =
+ hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+ auto callee2 =
+ hlo_module->AddEmbeddedComputation(CreateR0S32IdentityComputation());
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
+ auto x = builder.AddInstruction(
+ HloInstruction::CreateCall(r0s32_, {constant}, callee1));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateCall(r0s32_, {constant}, callee2));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_NE(x->to_apply(), y->to_apply());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "before unification", false, false, nullptr);
+ }
+ EXPECT_TRUE(
+ HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "after unification", false, false, nullptr);
+ }
+ EXPECT_EQ(2, hlo_module->computations().size());
+ EXPECT_EQ(x->to_apply(), y->to_apply());
+}
+
+TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ auto callee1 =
+ hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+ auto callee2 =
+ hlo_module->AddEmbeddedComputation(CreateR0S32AdditionComputation());
+
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(5)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(3)));
+ auto x = builder.AddInstruction(
+ HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee2));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32_, HloOpcode::kAdd, x, y));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_NE(x->to_apply(), y->to_apply());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "before unification", false, false, nullptr);
+ }
+ EXPECT_TRUE(
+ HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "after unification", false, false, nullptr);
+ }
+ EXPECT_EQ(2, hlo_module->computations().size());
+ EXPECT_EQ(x->to_apply(), y->to_apply());
+}
+
+// Do not unify subcomputations with different parameter shapes.
+TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ auto callee1 = hlo_module->AddEmbeddedComputation(
+ CreateR1S32AdditionComputation(r1s32_5_));
+ auto callee2 = hlo_module->AddEmbeddedComputation(
+ CreateR1S32AdditionComputation(r1s32_3_));
+
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r1s32_5_, "param1"));
+ auto param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r1s32_5_, "param2"));
+ auto x = builder.AddInstruction(
+ HloInstruction::CreateCall(r1s32_5_, {param1, param1}, callee1));
+ auto y = builder.AddInstruction(
+ HloInstruction::CreateCall(r1s32_3_, {param2, param2}, callee2));
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(S32, {8}), {x, y}, 0));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_NE(x->to_apply(), y->to_apply());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "before unification", false, false, nullptr);
+ }
+ EXPECT_FALSE(
+ HloSubcomputationUnification().Run(hlo_module.get()).ValueOrDie());
+ if (VLOG_IS_ON(1)) {
+ hlo_graph_dumper::DumpGraph(*hlo_module->entry_computation(),
+ "after unification", false, false, nullptr);
+ }
+ EXPECT_EQ(3, hlo_module->computations().size());
+ EXPECT_NE(x->to_apply(), y->to_apply());
+}
+
+// Regression test for b/31466798. Checks that entry_computation is still valid
+// after unification.
+TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) {
+ HloModule module(TestName());
+ for (int i = 0; i < 2; ++i) {
+ HloComputation::Builder builder("pow");
+ auto x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x"));
+ auto y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32_, "y"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32_, HloOpcode::kPower, x, y));
+ if (i == 0) {
+ module.AddEmbeddedComputation(builder.Build());
+ } else {
+ module.AddEntryComputation(builder.Build());
+ }
+ }
+
+ EXPECT_TRUE(HloSubcomputationUnification().Run(&module).ValueOrDie());
+ EXPECT_EQ(1, module.computations().size());
+ EXPECT_EQ(module.computations().front().get(), module.entry_computation());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc
new file mode 100644
index 0000000000..bf6e3dd84d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/inliner.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/inliner.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+// InlinerVisitor traverses the HLO computation and inlines maps.
+class InlinerVisitor : public DfsHloVisitorWithDefault {
+ public:
+ explicit InlinerVisitor(HloComputation* computation)
+ : computation_(computation) {}
+
+ // Default visitor action is to do nothing and return OK.
+ Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
+ return Status::OK();
+ }
+
+ Status HandleMap(
+ HloInstruction* map,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> static_operands) override;
+
+ // Runs the visitor on a computation.
+ StatusOr<bool> Run(HloComputation* computation);
+
+ private:
+ // Current HloComputation instance the InlinerVisitor is traversing.
+ HloComputation* computation_;
+
+  // Whether inlining has changed the computation.
+ bool changed_ = false;
+};
+
+StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
+ changed_ = false;
+ computation_ = computation;
+ TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
+ return changed_;
+}
+
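+// Illustrative sketch (schematic HLO text only, not taken from a test in this
+// change): the HandleMap rewrite below turns a map of a single-operation
+// computation such as
+//
+//   %max (x: f32[], y: f32[]) -> f32[] { ROOT maximum(x, y) }
+//   %result = map({%lhs, %rhs}, to_apply=%max)
+//
+// into a direct application of the mapped operation,
+//
+//   %result = maximum(%lhs, %rhs)
+//
+// and turns a map of a constant-rooted computation into a broadcast of that
+// constant to the map's shape.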
+Status InlinerVisitor::HandleMap(
+ HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* function,
+ tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
+ HloInstruction& root = *function->root_instruction();
+ // TODO(b/29249531): Add DCE pass to remove unused HloComputations.
+  // Only inline functions that consist of a single operation, until a better
+  // profitability model for inlining is defined.
+ if (hlo_query::AllOperandsAreParameters(root)) {
+ if (root.opcode() == HloOpcode::kUpdate ||
+ root.opcode() == HloOpcode::kFusion ||
+ root.opcode() == HloOpcode::kIndex ||
+ root.opcode() == HloOpcode::kParameter ||
+ root.opcode() == HloOpcode::kTrace) {
+ // Cloning not supported for these instructions.
+ return Status::OK();
+ }
+    VLOG(10) << "inlining map({X ... Y}, op) => op(X ... Y) with function "
+ << root.ToShortString();
+    // If the function root is a constant, its shape may differ from the map
+    // shape, so a broadcast is needed. Otherwise, cloning the root with the
+    // map's shape and operands is sufficient.
+ if (root.opcode() != HloOpcode::kConstant) {
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ root.CloneWithNewOperands(map->shape(), operands));
+ computation_->ReplaceInstruction(map, placed_instruction);
+ } else {
+ // The constant is in an embedded computation and needs to be recreated
+ // as part of the computation that the broadcast is inserted into.
+ HloInstruction* constant = computation_->AddInstruction(root.Clone());
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ HloInstruction::CreateBroadcast(map->shape(), constant, {}));
+ computation_->ReplaceInstruction(map, placed_instruction);
+ }
+ changed_ = true;
+ return Status::OK();
+ }
+
+ return Status::OK();
+}
+
+StatusOr<bool> Inliner::Run(HloModule* module) {
+ InlinerVisitor visitor(/*computation=*/nullptr);
+ bool changed = false;
+ for (const std::unique_ptr<HloComputation>& computation :
+ module->computations()) {
+ TF_ASSIGN_OR_RETURN(bool computation_changed,
+ visitor.Run(computation.get()));
+ changed |= computation_changed;
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/inliner.h
new file mode 100644
index 0000000000..5d53443b83
--- /dev/null
+++ b/tensorflow/compiler/xla/service/inliner.h
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// A pass which performs inlining. This can result, for example, in a function
+// that was previously applied via Map being applied directly to the forwarded
+// operands (i.e., map({X, Y}, max) -> max(X, Y)).
+class Inliner : public HloPass {
+ public:
+ Inliner() : HloPass("inline") {}
+ ~Inliner() override = default;
+
+  // Run inlining on the given module. Returns whether the module was
+  // changed.
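+  //
+  // Typical usage (sketch): Inliner().Run(module) returns true when any map
+  // was rewritten; see inliner_test.cc in this change for concrete examples.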
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc
new file mode 100644
index 0000000000..0054edcf6a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/inliner_test.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/inliner.h"
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+using InlinerTest = HloTestBase;
+
+// Test that `map` with `max` is transformed to `max`.
+TEST_F(InlinerTest, MapMax) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto max_builder = HloComputation::Builder(TestName());
+ auto param1 = max_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "x"));
+ auto param2 = max_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r0f32, "y"));
+ max_builder.AddInstruction(HloInstruction::CreateBinary(
+ param1->shape(), HloOpcode::kMaximum, param1, param2));
+ auto max_f32 = max_builder.Build();
+
+ auto builder = HloComputation::Builder("MapMaxFunction");
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEmbeddedComputation(std::move(max_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+ HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+ Inliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ root = hlo_module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMaximum);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
+ LiteralTestUtil::ExpectEqual(*result, *expected);
+}
+
+// Test that a `map` of a constant function is changed to a `broadcast`.
+TEST_F(InlinerTest, MapConstant) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto const2_builder = HloComputation::Builder(TestName());
+ auto param1 = const2_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "x"));
+ (void)param1;
+ const2_builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
+ auto const2_f32 = const2_builder.Build();
+
+ auto builder = HloComputation::Builder("MapConstFunction");
+ auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEmbeddedComputation(std::move(const2_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+ HloInstruction* root = hlo_module->entry_computation()->root_instruction();
+ Inliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ root = hlo_module->entry_computation()->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kBroadcast);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+ auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
+ LiteralTestUtil::ExpectEqual(*result, *expected);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
new file mode 100644
index 0000000000..58f8dc92cc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -0,0 +1,295 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+
+#include <algorithm>
+#include <list>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+bool IsExpensive(const HloInstruction& instruction) {
+ switch (instruction.opcode()) {
+ // Cheap instructions.
+ case HloOpcode::kAbs:
+ case HloOpcode::kAdd:
+ case HloOpcode::kBitcast:
+ case HloOpcode::kBroadcast:
+ case HloOpcode::kCeil:
+ case HloOpcode::kClamp:
+ case HloOpcode::kConcatenate:
+ case HloOpcode::kConstant:
+ case HloOpcode::kConvert:
+ case HloOpcode::kCopy:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kEq:
+ case HloOpcode::kFloor:
+ case HloOpcode::kGe:
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kGt:
+ case HloOpcode::kInfeed:
+ case HloOpcode::kLe:
+ case HloOpcode::kLogicalAnd:
+ case HloOpcode::kLogicalNot:
+ case HloOpcode::kLogicalOr:
+ case HloOpcode::kLt:
+ case HloOpcode::kMaximum:
+ case HloOpcode::kMinimum:
+ case HloOpcode::kMultiply:
+ case HloOpcode::kNe:
+ case HloOpcode::kNegate:
+ case HloOpcode::kPad:
+ case HloOpcode::kReshape:
+ case HloOpcode::kReverse:
+ case HloOpcode::kSelect:
+ case HloOpcode::kSign:
+ case HloOpcode::kSlice:
+ case HloOpcode::kSubtract:
+ case HloOpcode::kTranspose:
+ case HloOpcode::kTuple:
+ return false;
+
+ // Expensive instructions.
+ case HloOpcode::kCall:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kCustomCall:
+ case HloOpcode::kDivide:
+ case HloOpcode::kDot:
+ case HloOpcode::kExp:
+ case HloOpcode::kFusion:
+ case HloOpcode::kIndex:
+ case HloOpcode::kLog:
+ case HloOpcode::kMap:
+ case HloOpcode::kParameter:
+ case HloOpcode::kPower:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kRemainder:
+ case HloOpcode::kRng:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kSort:
+ case HloOpcode::kTanh:
+ case HloOpcode::kTrace:
+ case HloOpcode::kUpdate:
+ case HloOpcode::kWhile:
+ case HloOpcode::kSend:
+ case HloOpcode::kRecv:
+ return true;
+ }
+}
+
+bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) {
+ return !(producer->users().size() == 1 &&
+ producer->users().count(consumer) == 1);
+}
+
+StatusOr<bool> InstructionFusion::Run(HloModule* module) {
+ bool changed = false;
+ for (auto& computation : module->computations()) {
+ computation_ = computation.get();
+
+ // We want to be able to remove arbitrary instructions from the post order
+ // and also compare positions of instructions in the post order. To make
+ // this possible, create vector of instructions in post order and create a
+ // map from HloInstruction* to the instruction's index in the vector. An
+    // instruction is "removed" from the vector by setting its element to
+ // nullptr.
+ std::list<HloInstruction*> post_order_list =
+ computation_->MakeInstructionPostOrder();
+ std::vector<HloInstruction*> post_order(post_order_list.begin(),
+ post_order_list.end());
+ tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
+ for (int i = 0; i < post_order.size(); ++i) {
+ InsertOrDie(&post_order_index, post_order[i], i);
+ }
+
+ // Instruction fusion effectively fuses edges in the computation graph
+ // (producer instruction -> consumer instruction) so we iterate over all
+ // edges. When we fuse an edge, we create a copy of the producer inside the
+ // fusion instruction.
+ while (!post_order.empty()) {
+ // We want to iterate in reverse post order, so remove from the back of
+ // the vector.
+ HloInstruction* instruction = post_order.back();
+ post_order.pop_back();
+
+ // Instructions are "removed" from the post order by nulling out the
+ // element in the vector, so if the pointer is null, continue to the next
+ // instruction in the sort.
+ if (instruction == nullptr) {
+ continue;
+ }
+
+ // Remove instruction from the index map to ensure the vector and map stay
+ // consistent.
+ post_order_index.erase(instruction);
+
+ if (!instruction->IsFusable() &&
+ instruction->opcode() != HloOpcode::kFusion) {
+ continue;
+ }
+
+ // Consider each operand of this instruction for fusion into this
+ // instruction. We want to consider the operands in a particular order to
+      // avoid creating duplicate instruction clones in the fusion instruction.
+ // For example, consider the following expression:
+ //
+ // A = ...
+ // B = op(A)
+ // C = op(A, B)
+ //
+      // If we are considering the operands of C for fusion into C, we might
+ // fuse A or B first. If we fuse A first, we get:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // C' = op(A', B) }
+ //
+ // Where A' and C' are clones of A and C, respectively. Now only B is an
+ // operand of the fusion instruction C_fusion, so then we fuse B:
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+ // B' = op(A)
+ // C' = op(A', B') }
+ //
+ // Now A is an operand of C_fusion again, so we then fuse A (again!):
+ //
+ // A = ...
+ // B = op(A)
+ // C_fusion = { A' = ...
+      //              A" = ...
+ // B' = op(A")
+ // C' = op(A', B') }
+ //
+ // We prevent this duplication by considering the operands in the reverse
+ // order they appear in the instruction post order. In the example, this
+ // ensures that B will be considered before A.
+ //
+ // We store the original indices of the operands to pass to ShouldFuse.
+ std::vector<int64> sorted_operand_numbers(instruction->operands().size());
+ std::iota(std::begin(sorted_operand_numbers),
+ std::end(sorted_operand_numbers), 0);
+ std::sort(
+ sorted_operand_numbers.begin(), sorted_operand_numbers.end(),
+ [&](int64 i, int64 j) {
+ // Instructions with higher indices in the post order come
+ // first.
+ return (
+ FindOrDie(post_order_index, instruction->mutable_operand(i)) >
+ FindOrDie(post_order_index, instruction->mutable_operand(j)));
+ });
+
+ for (int64 i : sorted_operand_numbers) {
+ HloInstruction* operand = instruction->mutable_operand(i);
+ if (operand->IsFusable() && ShouldFuse(instruction, i)) {
+ HloInstruction* fusion_instruction = Fuse(operand, instruction);
+
+ // Fusing an instruction into a fusion instruction can change the
+ // operand set of the fusion instruction. For simplicity just push the
+ // instruction to the top of the post_order and reconsider it for
+ // further fusion in the next iteration of the outer loop.
+ post_order.push_back(fusion_instruction);
+ InsertOrDie(&post_order_index, fusion_instruction,
+ post_order.size() - 1);
+ changed = true;
+
+ if (operand->user_count() == 0) {
+          // Operand is now dead. Remove from post order by setting its
+ // location to nullptr.
+ post_order[FindOrDie(post_order_index, operand)] = nullptr;
+ post_order_index.erase(operand);
+
+ // Remove from computation.
+ computation_->RemoveInstruction(operand);
+ }
+ break;
+ }
+ }
+ }
+ }
+ return changed;
+}
+
+HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
+ HloInstruction* consumer) {
+ HloInstruction* fusion_instruction;
+
+ VLOG(2) << "Fusing " << producer << " into " << consumer;
+
+ if (consumer->opcode() == HloOpcode::kFusion) {
+ fusion_instruction = consumer;
+ } else {
+ fusion_instruction =
+ computation_->AddInstruction(HloInstruction::CreateFusion(
+ consumer->shape(), ChooseKind(producer, consumer), consumer));
+ computation_->ReplaceInstruction(consumer, fusion_instruction);
+ }
+ fusion_instruction->FuseInstruction(producer);
+
+ return fusion_instruction;
+}
+
+bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
+ int64 operand_index) {
+ HloInstruction* producer = consumer->mutable_operand(operand_index);
+ // Cost condition: don't duplicate expensive instructions.
+ if (FusionWouldDuplicate(producer, consumer) &&
+ (IsExpensive(*producer) || !may_duplicate_)) {
+ return false;
+ }
+
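+  // If the consumer is already a fusion instruction, only the loop and input
+  // fusion kinds accept additional producers; other kinds are not extended.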
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
+ consumer->fusion_kind() != HloInstruction::FusionKind::kInput) {
+ return false;
+ }
+
+  // Cost condition: do not fuse expensive producers into consumers that
+  // reuse operand elements.
+ if (consumer->ReusesOperandElements(operand_index) &&
+ IsExpensive(*producer)) {
+ return false;
+ }
+
+ if (producer->CouldBeBitcast() &&
+ // We can't fuse parameters anyhow, so we leave the user unfused to become
+ // a bitcast. If the operand is not a parameter, we would break a
+ // potential fusion to make it a bitcast, which is not so clear a win.
+ producer->operand(0)->opcode() == HloOpcode::kParameter) {
+ return false;
+ }
+
+ return true;
+}
+
+HloInstruction::FusionKind InstructionFusion::ChooseKind(
+ const HloInstruction* producer, const HloInstruction* consumer) {
+ return HloInstruction::FusionKind::kLoop;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
new file mode 100644
index 0000000000..902df2dcd0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Returns true if the computation of the given instruction is significantly
+// more expensive than just writing all the values of the instruction's result
+// array. Expensive operations should not be duplicated.
+bool IsExpensive(const HloInstruction& instruction);
+
+// Returns true if fusing producer into consumer would cause producer to be
+// duplicated. This is the case if producer has uses other than consumer.
+bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
+
+// HLO pass which performs instruction fusion. Instructions are fused
+// "vertically", meaning producing instructions are fused into their consumers
+// with the intent that the loops which compute their values will be fused in
+// code generation. Derived classes define the ShouldFuse method to select
+// which instructions to fuse.
+class InstructionFusion : public HloPass {
+ public:
+ explicit InstructionFusion(bool may_duplicate = true)
+ : HloPass("fusion"), may_duplicate_(may_duplicate) {}
+ ~InstructionFusion() override {}
+
+  // Run instruction fusion on the given module. Returns whether the module
+  // was changed (instructions were fused).
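+  //
+  // Typical usage (sketch): a backend constructs this pass (or a subclass)
+  // and runs it over a module, for example
+  //   InstructionFusion(/*may_duplicate=*/true).Run(module).ValueOrDie();
+  // as exercised by instruction_fusion_test.cc in this change.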
+ StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // Returns whether the given producer instruction should be fused into the
+ // given consumer instruction. producer is necessarily an operand of consumer.
+ // Derived classes should define this method to specify which instructions
+ // should be fused. `operand_index` is which operand of the consumer the
+ // producer is.
+ //
+ // Instructions are traversed in reverse post order (computation root to
+ // leaves). This method is called for each operand of the instruction (where
+  // the operand is 'producer' and the instruction is 'consumer').
+ //
+ // Subtypes can override this with target-specific heuristics.
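+  //
+  // A hypothetical target-specific subclass might tighten the default policy
+  // roughly like this (sketch only; the class name is invented):
+  //
+  //   class MyBackendFusion : public InstructionFusion {
+  //    protected:
+  //     bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override {
+  //       return InstructionFusion::ShouldFuse(consumer, operand_index) &&
+  //              consumer->opcode() != HloOpcode::kReduce;
+  //     }
+  //   };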
+ virtual bool ShouldFuse(HloInstruction* consumer, int64 operand_index);
+
+ // Chooses a fusion kind for `producer` and `consumer`.
+ // Default method chooses `kLoop`.
+ virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
+ const HloInstruction* consumer);
+
+ // Current HloComputation instance the loop fuser is traversing.
+ HloComputation* computation_;
+
+ private:
+ HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
+
+  // Whether an instruction may be duplicated in order to fuse it.
+ bool may_duplicate_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(InstructionFusion);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
new file mode 100644
index 0000000000..2e3742ed75
--- /dev/null
+++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/instruction_fusion.h"
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+
+using InstructionFusionTest = HloTestBase;
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndOperandElementReusingConsumerNotFused) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+ HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
+ HloInstruction* broadcast2 =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(S32, {1}), exp1, {0}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+ EXPECT_TRUE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+}
+
+TEST_F(InstructionFusionTest,
+ NonCostlyProducerAndOperandElementReusingConsumerFused) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+ HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0));
+ HloInstruction* broadcast2 =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(S32, {1}), negate1, {0}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(broadcast2, computation->root_instruction());
+ EXPECT_TRUE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
+}
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+ HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
+ HloInstruction* reshape2 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), exp1));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(reshape2, computation->root_instruction());
+ EXPECT_TRUE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
+}
+
+TEST_F(InstructionFusionTest,
+ CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) {
+ HloComputation::Builder builder(TestName());
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0(5)));
+ HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0));
+ HloInstruction* transpose2 = builder.AddInstruction(
+ HloInstruction::CreateTranspose(ShapeUtil::MakeShape(S32, {}), exp1, {}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(transpose2, computation->root_instruction());
+ EXPECT_TRUE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+ EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(reshape1, computation->root_instruction());
+ EXPECT_FALSE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {1, 1}), param0));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(reshape1, computation->root_instruction());
+ EXPECT_FALSE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+}
+
+TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "0"));
+ auto transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {}), param0, {}));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(transpose1, computation->root_instruction());
+ EXPECT_FALSE(
+ InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
new file mode 100644
index 0000000000..a8f2a6b89c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -0,0 +1,1334 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
+
+#include <algorithm>
+#include <deque>
+#include <functional>
+#include <map>
+#include <memory>
+#include <numeric>
+#include <ostream>
+#include <set>
+#include <string>
+#include <tuple>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+
+std::ostream& operator<<(std::ostream& out,
+ const LayoutConstraint& constraint) {
+ out << constraint.ToString();
+ return out;
+}
+
+BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout,
+ const LogicalBuffer& buffer)
+ : layout_(layout), buffer_(&buffer) {
+ CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok());
+}
+
+string BufferLayoutConstraint::ToString() const {
+ return tensorflow::strings::Printf("BufferLayoutConstraint %s: %s",
+ buffer_->ToString().c_str(),
+ LayoutUtil::HumanString(layout_).c_str());
+}
+
+OperandLayoutConstraint::OperandLayoutConstraint(
+ const ShapeLayout& shape_layout, const HloInstruction* instruction,
+ int64 operand_no)
+ : shape_layout_(shape_layout),
+ instruction_(instruction),
+ operand_no_(operand_no) {
+ CHECK(shape_layout_.LayoutIsSet());
+ CHECK(ShapeUtil::Compatible(shape_layout.shape(),
+ instruction->operand(operand_no)->shape()));
+}
+
+string OperandLayoutConstraint::ToString() const {
+ return tensorflow::strings::Printf(
+ "OperandLayoutConstraint %s, operand %lld: %s",
+ instruction_->name().c_str(), operand_no_,
+ shape_layout_.ToString().c_str());
+}
+
+string ResultLayoutConstraint::ToString() const {
+ return tensorflow::strings::Printf("ResultLayoutConstraint: %s",
+ shape_layout_.ToString().c_str());
+}
+
+LayoutConstraints::LayoutConstraints(
+ const TuplePointsToAnalysis& points_to_analysis,
+ const HloComputation* computation)
+ : points_to_analysis_(points_to_analysis), computation_(computation) {
+ // Gather all array-shaped logical buffers into unconstrained_buffer_ids.
+ for (auto& buffer : points_to_analysis_.logical_buffers()) {
+ if (buffer->IsArray()) {
+ unconstrained_buffer_ids_.insert(buffer->id());
+ }
+ }
+}
+
+bool LayoutConstraints::OperandBufferForwarded(
+ const HloInstruction* instruction, int64 operand_no) const {
+ // The operand is potentially forwarded if the intersection of points-to sets
+ // of the operand and the instruction is non-empty.
+ auto output_buffers =
+ points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet();
+ auto operand_buffers =
+ points_to_analysis_.GetPointsToSet(instruction->operand(operand_no))
+ .CreateFlattenedSet();
+ std::vector<const LogicalBuffer*> intersection;
+ std::set_intersection(output_buffers.begin(), output_buffers.end(),
+ operand_buffers.begin(), operand_buffers.end(),
+ std::back_inserter(intersection));
+ return !intersection.empty();
+}
+
+Status LayoutConstraints::SetBufferLayout(const Layout& layout,
+ const LogicalBuffer& buffer) {
+ VLOG(3) << "SetBufferLayout : " << buffer << " : "
+ << LayoutUtil::HumanString(layout);
+
+ TF_RETURN_IF_ERROR(points_to_analysis_.VerifyBuffer(buffer));
+ if (!buffer.IsArray()) {
+ return FailedPrecondition(
+ "Layout of buffer %s cannot be constrained because buffer is not "
+ "array-shaped, has shape: %s",
+ buffer.ToString().c_str(),
+ ShapeUtil::HumanString(buffer.shape()).c_str());
+ }
+ TF_RETURN_IF_ERROR(
+ LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()));
+
+ const Layout* curr_layout = BufferLayout(buffer);
+ if (curr_layout != nullptr) {
+ if (!LayoutUtil::Equal(*curr_layout, layout)) {
+ return FailedPrecondition(
+ "Buffer %s already has the layout constraint %s, cannot add "
+ "incompatible constraint %s",
+ buffer.ToString().c_str(),
+ LayoutUtil::HumanString(*curr_layout).c_str(),
+ LayoutUtil::HumanString(layout).c_str());
+ }
+ // New constraint matches existing constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ auto new_constraint_it = buffer_constraints_.insert(
+ {&buffer, BufferLayoutConstraint(layout, buffer)});
+ added_constraints_.push_back(&new_constraint_it.first->second);
+
+ // Remove buffer from the set of unconstrained buffers.
+ TF_RET_CHECK(unconstrained_buffer_ids_.count(buffer.id()) == 1);
+ unconstrained_buffer_ids_.erase(buffer.id());
+
+ return Status::OK();
+}
+
+Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout,
+ const HloInstruction* instruction,
+ int64 operand_no) {
+ VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand "
+ << operand_no << " : "
+ << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+ const ShapeLayout* curr_shape_layout = OperandLayout(instruction, operand_no);
+ if (curr_shape_layout != nullptr) {
+ if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
+ return FailedPrecondition(
+ "Operand %lld of instruction %s already has a layout constraint "
+ "%s, cannot add incompatible constraint %s",
+ operand_no, instruction->name().c_str(),
+ curr_shape_layout->ToString().c_str(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ }
+ // New constraint matches existing constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ // If any buffers in the operand occur in the output of the instruction, then
+ // return an error. This case is not handled because such a constraint changes
+ // layouts beyond this immediate use and is complicated to handle.
+ if (OperandBufferForwarded(instruction, operand_no)) {
+ return FailedPrecondition(
+        "Cannot constrain layout of operand %lld of instruction %s "
+ "because instruction forwards operand's LogicalBuffer(s)",
+ operand_no, instruction->name().c_str());
+ }
+
+ auto key = std::make_pair(instruction, operand_no);
+ auto new_constraint_it = operand_constraints_.insert(
+ {key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction,
+ operand_no)});
+ added_constraints_.push_back(&new_constraint_it.first->second);
+
+ return Status::OK();
+}
+
+Status LayoutConstraints::SetArrayOperandLayout(
+ const Layout& layout, const HloInstruction* instruction, int64 operand_no) {
+ const HloInstruction* operand = instruction->operand(operand_no);
+ TF_RET_CHECK(ShapeUtil::IsArray(operand->shape()));
+ Shape shape(operand->shape());
+ *shape.mutable_layout() = layout;
+ TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape));
+ return SetOperandLayout(shape, instruction, operand_no);
+}
+
+Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) {
+ VLOG(3) << "SetResultLayout : "
+ << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+ const ShapeLayout* curr_shape_layout = ResultLayout();
+ if (curr_shape_layout != nullptr) {
+ if (!curr_shape_layout->MatchesLayoutInShape(shape_with_layout)) {
+ return FailedPrecondition(
+ "Result of computation %s already has the layout constraint %s, "
+ "cannot add incompatible constraint %s",
+ computation_->name().c_str(), curr_shape_layout->ToString().c_str(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ }
+ // New constraint matches existing constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ result_constraint_.reset(
+ new ResultLayoutConstraint(ShapeLayout(shape_with_layout)));
+ added_constraints_.push_back(result_constraint_.get());
+
+ return Status::OK();
+}
+
+Status LayoutConstraints::SetInstructionLayout(
+ const Shape& shape_with_layout, const HloInstruction* instruction) {
+ VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", "
+ << ShapeUtil::HumanStringWithLayout(shape_with_layout);
+
+ if (!ShapeUtil::Compatible(shape_with_layout, instruction->shape())) {
+ return FailedPrecondition(
+ "Instruction %s of shape %s cannot be assigned incompatible layout %s",
+ instruction->name().c_str(),
+ ShapeUtil::HumanString(instruction->shape()).c_str(),
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ }
+
+ // Create a BufferLayoutConstraint for each array shape in the output of the
+ // instruction.
+ return ShapeUtil::ForEachSubshape(
+ shape_with_layout,
+ [this, instruction](const Shape& subshape,
+ const ShapeIndex& index) -> Status {
+ // The precondition for this method is that the instruction defines all
+ // buffers in its output.
+ auto buffers =
+ points_to_analysis_.GetPointsToSet(instruction).element(index);
+ CHECK_EQ(1, buffers.size());
+ CHECK_EQ(buffers[0]->instruction(), instruction);
+
+ if (ShapeUtil::IsArray(subshape)) {
+ return SetBufferLayout(subshape.layout(), *buffers[0]);
+ } else {
+ return Status::OK();
+ }
+ });
+}
+
+const Layout* LayoutConstraints::BufferLayout(
+ const LogicalBuffer& buffer) const {
+ auto it = buffer_constraints_.find(&buffer);
+ return it == buffer_constraints_.end() ? nullptr : &it->second.layout();
+}
+
+const ShapeLayout* LayoutConstraints::OperandLayout(
+ const HloInstruction* instruction, int64 operand_no) const {
+ auto it = operand_constraints_.find(std::make_pair(instruction, operand_no));
+ return it == operand_constraints_.end() ? nullptr
+ : &it->second.shape_layout();
+}
+
+const ShapeLayout* LayoutConstraints::ResultLayout() const {
+ return result_constraint_ ? &result_constraint_->shape_layout() : nullptr;
+}
+
+string LayoutConstraints::ToString() const {
+ string output;
+ tensorflow::strings::StrAppend(&output, "LayoutConstraints for computation ",
+ computation_->name(), ":\n");
+ for (auto* instruction : computation_->MakeInstructionPostOrder()) {
+ tensorflow::strings::StrAppend(&output, " ", instruction->ToShortString(),
+ "\n");
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ if (OperandLayout(instruction, i) != nullptr) {
+ tensorflow::strings::StrAppend(
+ &output, " operand (", i, "): ",
+ OperandLayout(instruction, i)->ToString(), "\n");
+ }
+ }
+ for (const LogicalBuffer* buffer :
+ points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
+ if (BufferLayout(*buffer) != nullptr) {
+ tensorflow::strings::StrAppend(
+ &output, " ", buffer->ToString(), " : ",
+ LayoutUtil::HumanString(*BufferLayout(*buffer)), "\n");
+ }
+ }
+ }
+
+ if (ResultLayout() != nullptr) {
+ tensorflow::strings::StrAppend(&output, " => ", ResultLayout()->ToString(),
+ "\n");
+ }
+ return output;
+}
+
+Status LayoutAssignment::AddMandatoryConstraints(
+ const ComputationLayout& computation_layout, HloComputation* computation,
+ LayoutConstraints* constraints) {
+ VLOG(3) << "Adding mandatory layout constraints to computation "
+ << computation->name();
+
+ // Constrain layouts of instructions which define values with pre-existing
+ // layouts.
+ for (auto& instruction : computation->instructions()) {
+ Shape const* shape_with_layout = nullptr;
+ if (instruction->opcode() == HloOpcode::kConstant) {
+ // Constant layouts must match the layout of their literal.
+ shape_with_layout = &instruction->literal().shape();
+ } else if (instruction->opcode() == HloOpcode::kInfeed) {
+ // Infeed layouts must match the layout of the original inserted
+ // instruction.
+ // TODO(b/31425034): Change infeeds to be more like parameters, with
+ // shapes in the ComputationLayout.
+ shape_with_layout = &instruction->shape();
+ } else if (instruction->opcode() == HloOpcode::kParameter) {
+ // Parameter layouts must match the respective layout in
+ // ComputationLayout.
+ shape_with_layout =
+ &computation_layout.parameter_layout(instruction->parameter_number())
+ .shape();
+ }
+ if (shape_with_layout != nullptr) {
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(*shape_with_layout,
+ instruction.get()));
+ }
+ }
+
+ // Constrain layouts of instructions which call computations which have
+ // already been assigned layouts. Instructions which call computations in a
+  // parallel element-wise context (e.g., map or reduce) do not need layout
+ // constraints because they operate on scalars.
+ for (auto& instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCall) {
+ // kCall instruction operands and output must match the ComputationLayout
+ // of the called computation.
+ const ComputationLayout& called_computation_layout =
+ FindOrDie(computation_layouts_, instruction->to_apply());
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ called_computation_layout.result_layout().shape(),
+ instruction.get()));
+ TF_RET_CHECK(instruction->operand_count() ==
+ called_computation_layout.parameter_count());
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ called_computation_layout.parameter_layout(i).shape(),
+ instruction.get(), i));
+ }
+ } else if (instruction->opcode() == HloOpcode::kWhile) {
+ // Layout of input and output of kWhile instruction must be equal and must
+ // match both input and output of body computation. Also, the input of
+ // condition computation must match kWhile layout.
+ HloComputation* body = instruction->while_body();
+ HloComputation* condition = instruction->while_condition();
+ const HloInstruction* init = instruction->operand(0);
+ const ComputationLayout& body_layout =
+ FindOrDie(computation_layouts_, body);
+ const ComputationLayout& condition_layout =
+ FindOrDie(computation_layouts_, condition);
+
+ // Check a few invariants irrespective of layout.
+ CHECK_EQ(1, instruction->operand_count());
+ CHECK_EQ(1, body->num_parameters());
+ CHECK_EQ(1, condition->num_parameters());
+ DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
+ body_layout.parameter_shape(0)));
+ DCHECK(ShapeUtil::Compatible(body_layout.result_shape(),
+ condition_layout.parameter_shape(0)));
+ DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
+
+ // Return error if earlier layout assignment of the embedded computations
+ // has produced conflicting layouts.
+ if (!ShapeUtil::Equal(body_layout.result_shape(),
+ body_layout.parameter_shape(0))) {
+ return InternalError(
+ "Parameter and result of body computation %s of while instruction "
+ "%s have different layouts: %s vs %s",
+ body->name().c_str(), instruction->name().c_str(),
+ ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
+ ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
+ }
+ if (!ShapeUtil::Equal(body->root_instruction()->shape(),
+ condition->parameter_instruction(0)->shape())) {
+ return InternalError(
+ "Parameter of condition computation %s of while instruction "
+ "%s does not match body computation %s result: %s vs %s",
+ condition->name().c_str(), instruction->name().c_str(),
+ body->name().c_str(),
+ ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
+ ShapeUtil::HumanString(body_layout.result_shape()).c_str());
+ }
+
+ // Constrain the output and the operand of the while instruction to match
+ // the computations.
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ body_layout.result_shape(), instruction.get()));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ body_layout.result_shape(), instruction.get(), 0));
+ } else if (instruction->opcode() == HloOpcode::kCustomCall) {
+      // Add constraints for the operands and the result of kCustomCall
+      // instructions. For now we only support row major layouts for all
+      // inputs and outputs.
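+      // For example (illustrative): for a rank-3 operand the lambda below
+      // fills dimension_order with {2, 1, 0}, i.e. a minor-to-major ordering
+      // in which dimension 0 is major, which is a row major layout.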
+ auto row_major_shape = [](const Shape& old_shape) {
+ Shape new_shape(old_shape);
+ std::vector<int64> dimension_order(new_shape.dimensions_size());
+ std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+ *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+ return new_shape;
+ };
+
+ Shape result_shape(row_major_shape(instruction->shape()));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(result_shape, instruction.get()));
+ for (int64 i = 0; i < instruction->operand_count(); ++i) {
+ const Shape& operand_shape = instruction->operand(i)->shape();
+ // Opaque operands don't get a layout constraint.
+ if (ShapeUtil::IsOpaque(operand_shape)) {
+ continue;
+ }
+
+ Shape row_major_operand_shape(row_major_shape(operand_shape));
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ row_major_operand_shape, instruction.get(), i));
+ }
+ }
+ }
+
+ // Finally set the result layout to match ComputationLayout.
+ return constraints->SetResultLayout(
+ computation_layout.result_layout().shape());
+}
+
+namespace {
+
+// The operands of a call must match the layouts of parameters in the
+// ComputationLayout, and the call instruction itself must match the result
+// layout in the ComputationLayout.
+Status CheckCallLayout(HloInstruction* call,
+ const ComputationLayout& computation_layout) {
+ HloComputation* computation = call->to_apply();
+ TF_RET_CHECK(computation->num_parameters() == call->operand_count());
+ for (int64 i = 0; i < computation->num_parameters(); ++i) {
+ TF_RET_CHECK(computation_layout.parameter_layout(i).MatchesLayoutInShape(
+ call->operand(i)->shape()));
+ }
+ TF_RET_CHECK(
+ computation_layout.result_layout().MatchesLayoutInShape(call->shape()));
+ return Status::OK();
+}
+
+// Custom calls have fixed input and output layouts.
+Status CheckCustomCallLayout(HloInstruction* custom_call) {
+ for (const HloInstruction* operand : custom_call->operands()) {
+ TF_RET_CHECK(
+ LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+ }
+ TF_RET_CHECK(
+ LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
+ return Status::OK();
+}
+
+// For a while instruction, all the following layouts must be the same:
+// (1) init operand
+// (2) condition computation parameter
+// (3) body computation parameter
+// (4) body computation result
+// (5) while instruction result
+Status CheckWhileLayout(HloInstruction* while_inst,
+ const ComputationLayout& condition_computation_layout,
+ const ComputationLayout& body_computation_layout) {
+ auto init_shape = while_inst->operand(0)->shape();
+ TF_RET_CHECK(
+ condition_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ init_shape));
+ TF_RET_CHECK(body_computation_layout.parameter_layout(0).MatchesLayoutInShape(
+ init_shape));
+ TF_RET_CHECK(
+ body_computation_layout.result_layout().MatchesLayoutInShape(init_shape));
+ TF_RET_CHECK(
+ LayoutUtil::LayoutsInShapesEqual(init_shape, while_inst->shape()));
+ return Status::OK();
+}
+
+// Fusion parameters must match the layout of the fusion instruction's operands,
+// and the root of the fusion expression must match the layout of the fusion
+// instruction.
+Status CheckFusionLayout(HloInstruction* fusion) {
+ TF_RET_CHECK(HloOpcode::kFusion == fusion->opcode());
+
+ TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+ fusion->shape(), fusion->fused_expression_root()->shape()));
+ for (int64 i = 0; i < fusion->operand_count(); ++i) {
+ TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+ fusion->fused_parameter(i)->shape(), fusion->operand(i)->shape()));
+ }
+ return Status::OK();
+}
+
+// The layout of a parameter must match the respective layout in the
+// computation's ComputationLayout.
+Status CheckParameterLayout(HloInstruction* parameter,
+ const ComputationLayout& computation_layout) {
+ const ShapeLayout& parameter_layout =
+ computation_layout.parameter_layout(parameter->parameter_number());
+ if (!parameter_layout.MatchesLayoutInShape(parameter->shape())) {
+ return InternalError(
+ "parameter instruction %s does not match layout of computation "
+ "shape: %s",
+ parameter->ToString().c_str(), parameter_layout.ToString().c_str());
+ }
+ return Status::OK();
+}
+
+// The layout of a constant instruction must match the layout of its literal.
+Status CheckConstantLayout(HloInstruction* constant) {
+ if (!LayoutUtil::LayoutsInShapesEqual(constant->literal().shape(),
+ constant->shape())) {
+ return InternalError(
+ "constant instruction %s does not match the layout of its literal %s",
+ constant->ToString().c_str(),
+ ShapeUtil::HumanStringWithLayout(constant->literal().shape()).c_str());
+ }
+ return Status::OK();
+}
+
+// Check that all layouts in the module have been set and satisfy all necessary
+// conditions.
+Status CheckLayouts(
+ HloModule* module,
+ const std::map<HloComputation*, ComputationLayout>& computation_layouts) {
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+ for (auto& computation : module->computations()) {
+ for (auto& instruction : computation->instructions()) {
+ // Verify every instruction has a layout and the layout is valid for the
+ // shape.
+ TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape()));
+
+ // Use points-to analysis to verify that every subshape element in the
+ // output of the instruction matches the layout of the logical buffer
+ // which could be the source of the subshape value.
+ const PointsToSet& points_to_set =
+ points_to_analysis->GetPointsToSet(instruction.get());
+ TF_RETURN_IF_ERROR(points_to_set.ForEachElement(
+ [&instruction](
+ ShapeIndex index, bool is_leaf,
+ const std::vector<const LogicalBuffer*>& buffers) -> Status {
+ if (is_leaf) {
+ const Shape& instruction_subshape =
+ ShapeUtil::GetSubshape(instruction->shape(), index);
+ for (const LogicalBuffer* buffer : buffers) {
+ if (!ShapeUtil::Equal(instruction_subshape, buffer->shape())) {
+ return InternalError(
+ "Layout of instruction %s at index {%s} does not match "
+ "source LogicalBuffer %s: %s vs %s",
+ instruction->name().c_str(),
+ tensorflow::str_util::Join(index, ",").c_str(),
+ buffer->ToString().c_str(),
+ ShapeUtil::HumanStringWithLayout(instruction_subshape)
+ .c_str(),
+ ShapeUtil::HumanStringWithLayout(buffer->shape())
+ .c_str());
+ }
+ }
+ }
+ return Status::OK();
+ }));
+
+ // Verify instructions that have special layout constraints.
+ switch (instruction->opcode()) {
+ case HloOpcode::kCall:
+ TF_RETURN_IF_ERROR(CheckCallLayout(
+ instruction.get(),
+ FindOrDie(computation_layouts, instruction->to_apply())));
+ break;
+ case HloOpcode::kCustomCall:
+ TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction.get()));
+ break;
+ case HloOpcode::kFusion:
+ TF_RETURN_IF_ERROR(CheckFusionLayout(instruction.get()));
+ break;
+ case HloOpcode::kParameter:
+ TF_RETURN_IF_ERROR(CheckParameterLayout(
+ instruction.get(),
+ FindOrDie(computation_layouts, instruction->parent())));
+ break;
+ case HloOpcode::kConstant:
+ TF_RETURN_IF_ERROR(CheckConstantLayout(instruction.get()));
+ break;
+ case HloOpcode::kWhile:
+ TF_RETURN_IF_ERROR(CheckWhileLayout(
+ instruction.get(),
+ FindOrDie(computation_layouts, instruction->while_condition()),
+ FindOrDie(computation_layouts, instruction->while_body())));
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ // Finally verify the result layout matches the layout of the entry
+ // computation root.
+ TF_RET_CHECK(ShapeUtil::Equal(
+ module->entry_computation()->root_instruction()->shape(),
+ FindOrDie(computation_layouts, module->entry_computation())
+ .result_layout()
+ .shape()));
+
+ return Status::OK();
+}
+
+} // namespace
+
+LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout)
+ : HloPass("layout-assignment"),
+ entry_computation_layout_(entry_computation_layout) {
+ VLOG(1) << "entry computation layout given to layout assignment: "
+ << entry_computation_layout_->ToString();
+ // Layouts of all parameter instructions must be set.
+ for (const ShapeLayout& parameter_layout :
+ entry_computation_layout_->parameter_layouts()) {
+ CHECK(parameter_layout.LayoutIsSet());
+ }
+ // If the result layout is not set, then choose the default.
+ // TODO(b/29118294): Choose a better layout in this case.
+ if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
+ entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
+ }
+}
+
+namespace {
+
+// Given a permutation of `{0, 1, ..., n}` `indices`, returns a permutation of
+// `{0, 1, ..., n - to_delete.size() + to_insert.size()}` by deleting the
+// indices `to_delete` wherever in `indices` they are, and inserting the indices
+// `to_insert` arbitrarily at the back.
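+//
+// A small worked example (values chosen purely for illustration):
+//   DeleteAndInsertIndices(/*to_delete=*/{2}, /*to_insert=*/{0},
+//                          /*indices=*/{0, 2, 1})
+// first erases the entry 2 and renumbers the remaining entries to {0, 1},
+// then shifts every entry >= 0 up by one and appends 0, yielding {1, 2, 0}.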
+tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>
+DeleteAndInsertIndices(
+ std::vector<int64> to_delete, std::vector<int64> to_insert,
+ tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64> indices) {
+ std::sort(to_delete.begin(), to_delete.end(), std::greater<int64>());
+ std::sort(to_insert.begin(), to_insert.end(), std::less<int64>());
+ for (auto index : to_delete) {
+ auto i = indices.begin();
+ while (i != indices.end()) {
+ if (*i == index) {
+ i = indices.erase(i);
+ } else {
+ if (*i > index) {
+ (*i)--;
+ }
+ ++i;
+ }
+ }
+ }
+ for (auto index : to_insert) {
+ for (auto i = indices.begin(); i != indices.end(); ++i) {
+ if (*i >= index) {
+ (*i)++;
+ }
+ }
+ indices.Add(index);
+ }
+ return indices;
+}
+
+} // namespace
+
+std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
+ const Layout& output_layout, const HloInstruction* instruction,
+ int64 operand_no) {
+ const HloInstruction* operand = instruction->operand(operand_no);
+
+ CHECK(ShapeUtil::IsArray(instruction->shape()) &&
+ ShapeUtil::IsArray(operand->shape()));
+
+ if (instruction->IsElementwiseOnOperand(operand_no) &&
+ !ShapeUtil::IsScalar(operand->shape()) &&
+ ShapeUtil::Rank(operand->shape()) ==
+ ShapeUtil::Rank(instruction->shape())) {
+ // Assign operands the same layout as the instruction, so that
+ // 1) the elementwise operation can reuse its operand's buffer, and
+ // 2) the input and output elements can reuse the same linear index.
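+ //
+ // For example (illustrative): an f32[10,20] Add whose output layout is
+ // constrained to minor_to_major {0,1} is given a {0,1} layout for each of
+ // its array operands by this rule.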
+ //
+ // TODO(jingyue): Other operations, such as kSlice and kConcat, can benefit
+ // from assigning the same layout to input and output.
+ return MakeUnique<Layout>(output_layout);
+ }
+
+ if (instruction->opcode() == HloOpcode::kReshape) {
+ // Pick the operand layout that makes the reshape a bitcast. If the reshape
+ // only inserts or deletes degenerate dimensions, we can easily compute the
+ // desired layout by accordingly inserting and deleting the elements in the
+ // minor-to-major list.
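+ //
+ // For illustration (hypothetical shapes): collapsing f32[5,1,7] into
+ // f32[5,7] with an output minor_to_major of {1,0} gives deleted_indices={1}
+ // and inserted_indices={}, so the chosen operand minor_to_major is {2,0,1},
+ // which lays the operand out exactly like the output and therefore makes
+ // the reshape a bitcast.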
+ bool merely_inserts_or_deletes_1_sized_dims;
+ std::vector<int64> inserted_indices, deleted_indices;
+ std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
+ inserted_indices) =
+ instruction->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+ if (merely_inserts_or_deletes_1_sized_dims) {
+ Layout operand_layout = LayoutUtil::MakeLayout(
+ AsInt64Slice(DeleteAndInsertIndices(inserted_indices, deleted_indices,
+ output_layout.minor_to_major())));
+ TF_CHECK_OK(
+ LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
+ return MakeUnique<Layout>(operand_layout);
+ }
+ }
+
+ if (instruction->opcode() == HloOpcode::kTranspose) {
+ // Pick the operand layout that makes the transpose a bitcast.
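+ // For example (illustrative): a rank-2 transpose with dimensions={1,0} and
+ // an output minor_to_major of {1,0} is given an operand minor_to_major of
+ // {0,1}; the physical order of the dimensions is then the same on both
+ // sides, so the transpose is a bitcast.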
+ std::vector<int64> perm =
+ ComposePermutations(instruction->dimensions(),
+ AsInt64Slice(output_layout.minor_to_major()));
+ Layout operand_layout = LayoutUtil::MakeLayout(perm);
+ TF_CHECK_OK(
+ LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
+ return MakeUnique<Layout>(operand_layout);
+ }
+
+ return nullptr;
+}
+
+std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
+ const Layout& operand_layout, const HloInstruction* user,
+ int64 operand_no) {
+ const HloInstruction* operand = user->operand(operand_no);
+
+ CHECK(ShapeUtil::IsArray(user->shape()) &&
+ ShapeUtil::IsArray(operand->shape()));
+
+ if (user->IsElementwiseOnOperand(operand_no) &&
+ !ShapeUtil::IsScalar(operand->shape()) &&
+ ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(user->shape())) {
+ // Assign users the same layout as the operand.
+ return MakeUnique<Layout>(operand_layout);
+ }
+
+ if (user->opcode() == HloOpcode::kReshape) {
+ // Pick the user layout that makes the reshape a bitcast.
+ bool merely_inserts_or_deletes_1_sized_dims;
+ std::vector<int64> inserted_indices, deleted_indices;
+ std::tie(merely_inserts_or_deletes_1_sized_dims, deleted_indices,
+ inserted_indices) =
+ user->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
+ if (merely_inserts_or_deletes_1_sized_dims) {
+ Layout user_layout = LayoutUtil::MakeLayout(AsInt64Slice(
+ DeleteAndInsertIndices(deleted_indices, inserted_indices,
+ operand_layout.minor_to_major())));
+ TF_CHECK_OK(
+ LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+ return MakeUnique<Layout>(user_layout);
+ }
+ }
+
+ if (user->opcode() == HloOpcode::kTranspose) {
+ // Pick the user layout that makes the transpose a bitcast.
+ // To become a bitcast, the layouts need to satisfy
+ // collapsing_order * output_layout = input_layout
+ // so output_layout = inverse(collapsing_order) * input_layout
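+ //
+ // For example (illustrative): with dimensions={1,0} and an operand
+ // minor_to_major of {0,1}, the chosen user minor_to_major is {1,0}, the
+ // mirror image of the operand case above.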
+ std::vector<int64> perm =
+ Permute(InversePermutation(user->dimensions()),
+ AsInt64Slice(operand_layout.minor_to_major()));
+ Layout user_layout = LayoutUtil::MakeLayout(perm);
+ TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
+ return MakeUnique<Layout>(user_layout);
+ }
+
+ return nullptr;
+}
+
+Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) {
+ // Gathers all initial constraints in a worklist and propagates them in
+ // depth-first order. DFS order seems to be better than BFS because a
+ // constraint is propagated as far as possible before propagating unrelated
+ // constraints which makes it less likely that conflicting constraints will be
+ // propagated to instructions. However, we should experiment with other orders
+ // too.
+ std::deque<const LayoutConstraint*> worklist;
+
+ // Lambda for moving newly added constraints to the worklist.
+ auto add_new_constraints_to_worklist = [constraints, &worklist]() {
+ // Add constraints to the front of the deque for DFS ordering.
+ for (auto* constraint : constraints->ConsumeAddedConstraints()) {
+ worklist.push_front(constraint);
+ }
+ };
+ add_new_constraints_to_worklist();
+
+ while (!worklist.empty()) {
+ const LayoutConstraint* layout_constraint = worklist.front();
+ worklist.pop_front();
+ VLOG(2) << "Propagating " << layout_constraint->ToString()
+ << " to its neighbors.";
+ if (auto* buffer_constraint =
+ dynamic_cast<const BufferLayoutConstraint*>(layout_constraint)) {
+ TF_RETURN_IF_ERROR(
+ PropagateBufferConstraint(*buffer_constraint, constraints));
+ } else if (auto* operand_constraint =
+ dynamic_cast<const OperandLayoutConstraint*>(
+ layout_constraint)) {
+ TF_RETURN_IF_ERROR(
+ PropagateOperandConstraint(*operand_constraint, constraints));
+ } else if (auto* result_constraint =
+ dynamic_cast<const ResultLayoutConstraint*>(
+ layout_constraint)) {
+ TF_RETURN_IF_ERROR(
+ PropagateResultConstraint(*result_constraint, constraints));
+ } else {
+ LOG(FATAL) << "Invalid constraint type: " << *layout_constraint;
+ }
+
+ add_new_constraints_to_worklist();
+ }
+ return Status::OK();
+}
+
+namespace {
+
+// Returns a vector containing all array-shaped uses (instruction and operand
+// number) of the given logical buffer or its aliases.
+std::vector<std::pair<const HloInstruction*, int64>> GetArrayUsesOfBuffer(
+ const LogicalBuffer& buffer,
+ const TuplePointsToAnalysis& points_to_analysis) {
+ CHECK(buffer.IsArray());
+ std::vector<std::pair<const HloInstruction*, int64>> uses;
+ for (const auto& buffer_alias : points_to_analysis.GetBufferAliases(buffer)) {
+ if (!ShapeUtil::IsArray(buffer_alias.instruction()->shape())) {
+ continue;
+ }
+ // This alias must be the top-level (index == {}) of the instruction's
+ // result because the instruction produces an array.
+ CHECK(buffer_alias.index().empty());
+
+ // Add all uses of the instruction's output.
+ for (const HloInstruction* user : buffer_alias.instruction()->users()) {
+ for (int64 operand_no :
+ user->OperandIndices(buffer_alias.instruction())) {
+ uses.emplace_back(user, operand_no);
+ }
+ }
+ }
+ return uses;
+}
+
+} // namespace
+
+Status LayoutAssignment::PropagateUseConstraintToDefs(
+ const ShapeLayout& shape_layout, const HloInstruction* instruction,
+ LayoutConstraints* constraints) {
+ // Try to set all logical buffers which may be sources of the given operand to
+ // match the given layout.
+ const PointsToSet& points_to_set =
+ constraints->points_to_analysis().GetPointsToSet(instruction);
+ return points_to_set.ForEachElement(
+ [this, &shape_layout, constraints](
+ const ShapeIndex& index, bool is_leaf,
+ const std::vector<const LogicalBuffer*>& buffers) -> Status {
+ if (is_leaf) {
+ for (const LogicalBuffer* buffer : buffers) {
+ if (constraints->BufferLayout(*buffer) == nullptr &&
+ ShapeUtil::IsArray(buffer->shape())) {
+ TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
+ ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(),
+ *buffer));
+ }
+ }
+ }
+ return Status::OK();
+ });
+}
+
+Status LayoutAssignment::PropagateOperandConstraint(
+ const OperandLayoutConstraint& operand_constraint,
+ LayoutConstraints* constraints) {
+ // Try to set the layout of the logical buffers in the given operand to match
+ // the constrained layout. This avoids copies.
+ TF_RETURN_IF_ERROR(
+ PropagateUseConstraintToDefs(operand_constraint.shape_layout(),
+ operand_constraint.operand(), constraints));
+
+ // For array-shaped operands and user instructions, try to pick a minimum
+ // cost layout. For example, if the operand of an elementwise instruction is
+ // constrained to a certain layout, we want the output of the instruction to
+ // have the same layout.
+ const HloInstruction* operand = operand_constraint.operand();
+ const HloInstruction* user = operand_constraint.instruction();
+ if (!ShapeUtil::IsArray(operand->shape()) ||
+ !ShapeUtil::IsArray(user->shape())) {
+ return Status::OK();
+ }
+
+ // Only try to choose a low cost layout if the instruction 'user' defines its
+ // output (ie, doesn't forward a buffer from elsewhere).
+ if (constraints->OperandBufferForwarded(user,
+ operand_constraint.operand_no())) {
+ return Status::OK();
+ }
+ TF_ASSIGN_OR_RETURN(
+ const LogicalBuffer* buffer,
+ constraints->points_to_analysis().GetBufferDefinedAt(user, /*index=*/{}));
+
+ if (constraints->BufferLayout(*buffer) == nullptr) {
+ std::unique_ptr<Layout> layout = ChooseOutputLayoutFromOperandLayout(
+ operand_constraint.shape_layout().layout(), user,
+ operand_constraint.operand_no());
+ if (layout != nullptr) {
+ TF_RETURN_IF_ERROR(constraints->SetBufferLayout(*layout, *buffer));
+ }
+ }
+ return Status::OK();
+}
+
+Status LayoutAssignment::PropagateBufferConstraint(
+ const BufferLayoutConstraint& buffer_constraint,
+ LayoutConstraints* constraints) {
+ // Only propagate array layouts.
+ const LogicalBuffer& buffer = buffer_constraint.buffer();
+ if (!buffer.IsArray()) {
+ return Status::OK();
+ }
+
+ // If this buffer is the result of an array-shaped op (as opposed to an array
+ // element in a tuple) try to propagate the layout to its operands.
+ if (buffer.IsTopLevel()) {
+ const HloInstruction* instruction = buffer.instruction();
+ // Propagate the def-constraint on an instruction to the use-constraints on
+ // its operands (use-def propagation).
+ for (int64 operand_no = 0; operand_no < instruction->operand_count();
+ ++operand_no) {
+ if (constraints->OperandLayout(instruction, operand_no) == nullptr &&
+ ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) {
+ std::unique_ptr<Layout> operand_layout =
+ ChooseOperandLayoutFromOutputLayout(buffer_constraint.layout(),
+ instruction, operand_no);
+ if (operand_layout != nullptr) {
+ TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
+ *operand_layout, instruction, operand_no));
+ }
+ }
+ }
+ }
+
+ // Propagate the layout to all array uses of the logical buffer. This skips
+ // uses of the buffer where the buffer is the element of a tuple.
+ for (const auto& user_operand_no :
+ GetArrayUsesOfBuffer(buffer, constraints->points_to_analysis())) {
+ const HloInstruction* user = user_operand_no.first;
+ int64 operand_no = user_operand_no.second;
+ // Only add an operand constraint if the user does not forward the buffer
+ // because this case is not handled in SetOperandLayout.
+ if (constraints->OperandLayout(user, operand_no) == nullptr &&
+ !constraints->OperandBufferForwarded(user, operand_no)) {
+ TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
+ buffer_constraint.layout(), user, operand_no));
+ }
+ }
+
+ return Status::OK();
+}
+
+Status LayoutAssignment::PropagateResultConstraint(
+ const ResultLayoutConstraint& result_constraint,
+ LayoutConstraints* constraints) {
+ // Propagate the use constraint of the root instruction up to the logical
+ // buffers which make up the result.
+ return PropagateUseConstraintToDefs(
+ result_constraint.shape_layout(),
+ constraints->computation()->root_instruction(), constraints);
+}
+
+namespace {
+
+// Infers the layout of the array at the given index in the given instruction's
+// output using points-to analysis. Precondition: The given instruction must
+// not produce this array value (that is, the array is forwarded from the
+// instruction's operands).
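+//
+// For example (illustrative): if `instruction` is a kTuple whose element {0}
+// is forwarded from an operand whose buffer has minor_to_major {1,0}, the
+// layout inferred at index {0} is {1,0}. If aliasing makes several source
+// buffers with different layouts possible, an error is returned instead.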
+StatusOr<Layout> InferArrayLayout(
+ const TuplePointsToAnalysis& points_to_analysis,
+ HloInstruction* instruction, const ShapeIndex& index) {
+ // This function should only be called for array shapes which don't yet have
+ // layouts.
+ const Shape& subshape = ShapeUtil::GetSubshape(instruction->shape(), index);
+ TF_RET_CHECK(ShapeUtil::IsArray(subshape));
+ TF_RET_CHECK(!subshape.has_layout());
+
+ // The instruction should not define the buffer at this index.
+ TF_RET_CHECK(
+ !points_to_analysis.InstructionDefinesBufferAtIndex(instruction, index));
+
+ const std::vector<const LogicalBuffer*>& source_buffers =
+ points_to_analysis.GetPointsToSet(instruction).element(index);
+ TF_RET_CHECK(!source_buffers.empty());
+
+ // Verify the layout is the same for every LogicalBuffer which this location
+ // ('instruction' and 'index') points to.
+ const Layout* first_buffer_layout = nullptr;
+ for (const LogicalBuffer* source_buffer : source_buffers) {
+ if (!source_buffer->shape().has_layout()) {
+ // This should not happen because we've assigned layouts to all
+ // instructions preceding this one.
+ return InternalError("LogicalBuffer %s does not have a layout",
+ source_buffer->ToString().c_str());
+ }
+
+ if (first_buffer_layout == nullptr) {
+ first_buffer_layout = &source_buffer->shape().layout();
+ } else if (!LayoutUtil::Equal(source_buffer->shape().layout(),
+ *first_buffer_layout)) {
+ // The points-to set is ambiguous for this index and the different source
+ // buffers have different layouts. This case is possible in valid XLA
+ // computations because we do not propagate BufferLayoutConstraints to all
+ // LogicalBuffers which may alias the constrained LogicalBuffer at some
+ // point in the computation.
+ return FailedPrecondition(
+ "Array at index {%s} in instruction %s aliases buffers %s "
+ "and %s which have different layouts",
+ tensorflow::str_util::Join(index, ",").c_str(),
+ instruction->name().c_str(), source_buffers[0]->ToString().c_str(),
+ source_buffer->ToString().c_str());
+ }
+ }
+
+ return *first_buffer_layout;
+}
+
+// Creates and returns a copy of the given instruction with a different
+// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
+// instruction producing the copy is returned.
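+//
+// For a two-element tuple %t this produces, roughly (names are illustrative):
+//   %gte.0 = GetTupleElement(%t, 0)
+//   %copy.0 = Copy(%gte.0)
+//   %gte.1 = GetTupleElement(%t, 1)
+//   %copy.1 = Copy(%gte.1)
+//   %tuple = Tuple(%copy.0, %copy.1)
+// with the layouts of the copies and the tuple taken from shape_with_layout.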
+StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ const Shape& shape_with_layout, HloInstruction* instruction) {
+ TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout));
+ DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape()));
+
+ if (ShapeUtil::IsTuple(instruction->shape())) {
+ // Deep-copy tuples.
+ std::vector<HloInstruction*> element_copies;
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
+ ++i) {
+ HloInstruction* gte = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+ i));
+
+ // Recurse to copy each element.
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * element_copy,
+ CreateCopyWithNewLayout(
+ ShapeUtil::GetSubshape(shape_with_layout, {i}), gte));
+ element_copies.push_back(element_copy);
+ }
+ // Gather element copies into a tuple with a new Tuple instruction.
+ HloInstruction* tuple_copy = instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(element_copies));
+ LayoutUtil::ClearLayout(tuple_copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, tuple_copy->mutable_shape()));
+ return tuple_copy;
+ } else if (ShapeUtil::IsArray(instruction->shape())) {
+ HloInstruction* copy =
+ instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction->shape(), HloOpcode::kCopy, instruction));
+ LayoutUtil::ClearLayout(copy->mutable_shape());
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ shape_with_layout, copy->mutable_shape()));
+
+ return copy;
+ } else {
+ return FailedPrecondition(
+ "Can only copy array and tuple shaped instructions");
+ }
+}
+
+// Creates a copy of the given operand if the operand's layout does not match
+// the given layout. This copy replaces the use in the given instruction. Tuple
+// operands will be deep-copied.
+Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no) {
+ HloInstruction* operand = instruction->mutable_operand(operand_no);
+ TF_RET_CHECK(operand_layout.LayoutIsSet());
+ TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
+
+ if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
+ // Operand layout already matches our constraint. Nothing to do.
+ return Status::OK();
+ }
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
+ CreateCopyWithNewLayout(operand_layout.shape(), operand));
+
+ instruction->ReplaceOperandWith(operand_no, operand_copy);
+ return Status::OK();
+}
+
+// For fusion instructions, set the layout of each fused parameter instruction
+// to match the layout of its corresponding fusion instruction operand. Also,
+// set the layout of the fused root to match the layout of the fusion
+// instruction itself.
+// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple
+// element array pointer load can be added.
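+//
+// For example (illustrative): in a loop fusion whose operands and output all
+// have layout {0,1}, each fused parameter is assigned {0,1}, the fused root
+// is assigned the fusion instruction's layout, and the remaining internal
+// fused instructions have their layouts cleared.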
+Status SetFusionLayouts(HloInstruction* fusion) {
+ TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion);
+ for (auto& fused_instruction : fusion->fused_instructions()) {
+ if (fused_instruction->opcode() == HloOpcode::kParameter) {
+ const HloInstruction* fusion_operand =
+ fusion->operand(fused_instruction->parameter_number());
+ DCHECK(ShapeUtil::Compatible(fusion_operand->shape(),
+ fused_instruction->shape()));
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ fusion_operand->shape(), fused_instruction->mutable_shape()));
+ } else if (fused_instruction.get() == fusion->fused_expression_root()) {
+ // The layout of the root of the fused expression must match the fusion
+ // instruction layout.
+ DCHECK(
+ ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape()));
+ TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
+ fusion->shape(), fused_instruction->mutable_shape()));
+ } else if (fused_instruction->opcode() != HloOpcode::kConstant &&
+ fused_instruction->opcode() != HloOpcode::kGetTupleElement &&
+ fused_instruction->opcode() != HloOpcode::kInfeed) {
+ // Internal fused instructions, with the exception of constants,
+ // GetTupleElement, and infeed, need no layout.
+ LayoutUtil::ClearLayout(fused_instruction->mutable_shape());
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
+ HloComputation* computation) {
+ VLOG(2) << "Assigning layouts to computation: " << computation->name();
+ XLA_VLOG_LINES(2, computation->ToString());
+ XLA_VLOG_LINES(2, constraints.ToString());
+
+ for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
+ LayoutUtil::ClearLayout(instruction->mutable_shape());
+
+ // Create a copy of an operand if the operand instruction's layout does not
+ // match the use constraint (OperandLayoutConstraint).
+ for (int64 operand_no = 0; operand_no < instruction->operand_count();
+ ++operand_no) {
+ const ShapeLayout* operand_layout =
+ constraints.OperandLayout(instruction, operand_no);
+ if (operand_layout != nullptr) {
+ TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
+ instruction, operand_no));
+ }
+ }
+
+ // Set the layouts of the array shapes this instruction defines as
+ // indicated by the respective BufferLayoutConstraints. Any array shapes
+ // in the output of the instruction which are not defined by the instruction
+ // (eg, array elements in a Tuple instruction) will be assigned below via
+ // inference.
+ for (const LogicalBuffer* buffer :
+ constraints.points_to_analysis().GetBuffersDefinedByInstruction(
+ instruction)) {
+ if (!ShapeUtil::IsArray(buffer->shape())) {
+ continue;
+ }
+
+ TF_RET_CHECK(buffer->instruction() == instruction);
+ Shape* buffer_subshape = ShapeUtil::GetMutableSubshape(
+ instruction->mutable_shape(), buffer->index());
+ const Layout* buffer_layout = constraints.BufferLayout(*buffer);
+ TF_RET_CHECK(buffer_layout != nullptr);
+ *buffer_subshape->mutable_layout() = *buffer_layout;
+ }
+
+ // Any remaining layouts in the output of the instruction must be
+ // inferrable using points-to analysis.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshape(
+ instruction->mutable_shape(),
+ [instruction, &constraints](Shape* subshape, const ShapeIndex& index) {
+ if (subshape->has_layout() || !ShapeUtil::IsArray(*subshape)) {
+ return Status::OK();
+ }
+ // Set Layout of subshape to match layout of LogicalBuffer which
+ // produces it.
+ TF_ASSIGN_OR_RETURN(*subshape->mutable_layout(),
+ InferArrayLayout(constraints.points_to_analysis(),
+ instruction, index));
+ return Status::OK();
+ }));
+
+ // Fusion instructions require some layouts to be set on fused instructions
+ // inside the fusion instruction.
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(SetFusionLayouts(instruction));
+ }
+
+ // Verify all layouts in the shape have been set.
+ TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
+ }
+
+ // Copy the root instruction's result if it does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
+
+ return Status::OK();
+}
+
+Status LayoutAssignment::RunOnComputation(
+ const ComputationLayout& computation_layout, HloComputation* computation) {
+ DCHECK(computation_layout.LayoutIsSet());
+ InsertOrDie(&computation_layouts_, computation, computation_layout);
+ VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
+ << ")";
+ VLOG(2) << " ComputationLayout = " << computation_layout.ToString();
+
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(computation->parent()));
+
+ // Construct LayoutConstraints with all layout constraints of the computation.
+ LayoutConstraints constraints(*points_to_analysis, computation);
+
+ // Add constraints required for correctness on all backends (eg, entry
+ // parameter layout constraints).
+ TF_RETURN_IF_ERROR(
+ AddMandatoryConstraints(computation_layout, computation, &constraints));
+
+ // Add any backend-specific constraints.
+ TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints));
+
+ // Propagates layouts from an HLO to its neighbors.
+ TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+
+ // While any unconstrained buffers remain, pick an arbitrary buffer, give it a
+ // layout and propagate the change.
+ while (!constraints.unconstrained_buffer_ids().empty()) {
+ int unconstrained_count = constraints.unconstrained_buffer_ids().size();
+
+ // Arbitrarily pick the first unconstrained buffer and give it the default
+ // layout. By construction unconstrained_buffer_ids() is ordered by
+ // LogicalBuffer::Id.
+ const LogicalBuffer& buffer = points_to_analysis->GetBuffer(
+ *constraints.unconstrained_buffer_ids().begin());
+ TF_RETURN_IF_ERROR(constraints.SetBufferLayout(
+ LayoutUtil::GetDefaultLayoutForShape(buffer.shape()), buffer));
+
+ TF_RETURN_IF_ERROR(PropagateConstraints(&constraints));
+
+ // To verify progress has been made, check that the number of unconstrained
+ // buffers has been reduced.
+ CHECK_LT(constraints.unconstrained_buffer_ids().size(),
+ unconstrained_count);
+ }
+
+ // All logical buffers should have constraints at this point. All that
+ // remains is to assign the constraints to the buffers and infer layouts for
+ // aliased buffers.
+ return AssignLayouts(constraints, computation);
+}
+
+StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
+ VLOG(2) << "Running layout assignment on module " << module->name();
+ XLA_VLOG_LINES(3, module->ToString());
+ if (VLOG_IS_ON(10)) {
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
+ "before layout assignment",
+ /*show_addresses=*/false,
+ /*show_layouts=*/true);
+ }
+
+ // Assign layouts to computations in an order such that a callee computation
+ // is handled before its caller computation. This ensures that the layout of
+ // all callers of a computation will agree.
+ for (auto* computation : module->MakeComputationPostOrder()) {
+ if (computation == module->entry_computation()) {
+ TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_,
+ module->entry_computation()));
+ } else {
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ // Setting all embedded computations to the default layout is potentially
+ // suboptimal.
+ computation_layout.SetToDefaultLayout();
+ TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, computation));
+ }
+ }
+
+ TF_RETURN_IF_ERROR(CheckLayouts(module, computation_layouts_));
+
+ VLOG(3) << "After layout assignment:";
+ XLA_VLOG_LINES(3, module->ToString());
+ if (VLOG_IS_ON(10)) {
+ hlo_graph_dumper::DumpGraph(*module->entry_computation(),
+ "after layout assignment",
+ /*show_addresses=*/false,
+ /*show_layouts=*/true);
+ }
+
+ // All layouts are reset then reassigned by this pass.
+ return true;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
new file mode 100644
index 0000000000..3c67a95412
--- /dev/null
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -0,0 +1,302 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
+
+#include <iosfwd>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Abstract base class for layout constraints. These constraint objects are
+// gathered together in a LayoutConstraints object.
+class LayoutConstraint {
+ public:
+ LayoutConstraint() = default;
+ virtual ~LayoutConstraint() = default;
+
+ virtual string ToString() const = 0;
+};
+
+std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint);
+
+// Layout constraint on a single LogicalBuffer. This constrains the layout of an
+// array produced by a particular instruction.
+class BufferLayoutConstraint : public LayoutConstraint {
+ public:
+ BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer);
+
+ const LogicalBuffer& buffer() const { return *buffer_; }
+ const Layout& layout() const { return layout_; }
+
+ string ToString() const override;
+
+ private:
+ const Layout layout_;
+ const LogicalBuffer* buffer_;
+};
+
+// Constraint on the layout of the operand of an instruction. The constrained
+// shape can be arbitrarily shaped (array or tuple). This is a constraint on the
+// use of a shaped value and is not a hard constraint on the instruction(s)
+// which define the value as copies may be inserted between the definition and
+// use.
+class OperandLayoutConstraint : public LayoutConstraint {
+ public:
+ OperandLayoutConstraint(const ShapeLayout& shape_layout,
+ const HloInstruction* instruction, int64 operand_no);
+
+ const ShapeLayout& shape_layout() const { return shape_layout_; }
+ const HloInstruction* instruction() const { return instruction_; }
+ const int64 operand_no() const { return operand_no_; }
+ const HloInstruction* operand() const {
+ return instruction_->operand(operand_no_);
+ }
+
+ string ToString() const override;
+
+ private:
+ const ShapeLayout shape_layout_;
+ const HloInstruction* instruction_;
+ int64 operand_no_;
+};
+
+// Constraint on the layout of the result of the entry computation.
+class ResultLayoutConstraint : public LayoutConstraint {
+ public:
+ explicit ResultLayoutConstraint(const ShapeLayout& shape_layout)
+ : shape_layout_(shape_layout) {}
+
+ const ShapeLayout& shape_layout() const { return shape_layout_; }
+ string ToString() const override;
+
+ private:
+ const ShapeLayout shape_layout_;
+};
+
+// Class encapsulating the layout constraints of the values in an HLO
+// computation.
+class LayoutConstraints {
+ public:
+ LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis,
+ const HloComputation* computation);
+ ~LayoutConstraints() = default;
+
+ const HloComputation* computation() const { return computation_; }
+ const TuplePointsToAnalysis& points_to_analysis() const {
+ return points_to_analysis_;
+ }
+
+ // Return a vector containing the constraints which have been added to the
+ // LayoutConstraints object since the construction of the object or since the
+ // last time ConsumeAddedConstraints() has been called. This is used to
+ // identify newly added constraints when propagating layouts.
+ std::vector<const LayoutConstraint*> ConsumeAddedConstraints() {
+ std::vector<const LayoutConstraint*> ret_vec(std::move(added_constraints_));
+ added_constraints_.clear();
+ return ret_vec;
+ }
+ void ClearAddedConstraints() { added_constraints_.clear(); }
+
+ // Returns the layout of a LogicalBuffer, the layout of the operand of the
+ // instruction, or the layout of the result of the computation, respectively,
+ // if it has been constrained. Otherwise return nullptr.
+ const Layout* BufferLayout(const LogicalBuffer& buffer) const;
+ const ShapeLayout* OperandLayout(const HloInstruction* instruction,
+ int64 operand_no) const;
+ const ShapeLayout* ResultLayout() const;
+
+ // Add a constraint on the layout of a LogicalBuffer, the layout of the
+ // operand of the instruction, or the layout of the result of the computation,
+ // respectively.
+ Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer);
+ Status SetOperandLayout(const Shape& shape_with_layout,
+ const HloInstruction* instruction, int64 operand_no);
+ Status SetResultLayout(const Shape& shape_with_layout);
+
+ // Convenience wrapper around SetOperandLayout for setting the layout of an
+ // operand using a Layout object. The operand must be array-shaped.
+ Status SetArrayOperandLayout(const Layout& layout,
+ const HloInstruction* instruction,
+ int64 operand_no);
+
+ // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers
+ // created by the instruction to the layouts in the given shape. The
+ // instruction must define every logical buffer in its output.
+ Status SetInstructionLayout(const Shape& shape_with_layout,
+ const HloInstruction* instruction);
+
+ // Returns true if any buffer in the given operand is forwarded to the output
+ // of the given instruction. For example, the Tuple instruction forwards the
+ // buffers of its operands and would return true for each of its operands.
+ bool OperandBufferForwarded(const HloInstruction* instruction,
+ int64 operand_no) const;
+
+ // Returns the set of logical buffers (by LogicalBuffer::Id) which do not
+ // yet have a layout constraint.
+ const std::set<LogicalBuffer::Id>& unconstrained_buffer_ids() const {
+ return unconstrained_buffer_ids_;
+ }
+
+ string ToString() const;
+
+ private:
+ // The set of BufferLayoutConstraints applied to the computation.
+ std::unordered_map<const LogicalBuffer*, BufferLayoutConstraint>
+ buffer_constraints_;
+
+ // The set of OperandLayoutConstraints applied to the computation.
+ using OperandConstraintKey = std::pair<const HloInstruction*, int64>;
+ std::map<OperandConstraintKey, OperandLayoutConstraint> operand_constraints_;
+
+ // The result constraint for the computation (can be null).
+ std::unique_ptr<ResultLayoutConstraint> result_constraint_;
+
+ // A vector which holds constraints as they are added. Can be cleared with
+ // ClearAddedConstraints.
+ std::vector<const LayoutConstraint*> added_constraints_;
+
+ // Points-to analysis for the module. Used to propagate constraints through
+ // the HLO graph.
+ const TuplePointsToAnalysis& points_to_analysis_;
+
+ // Array-shaped buffers which have not yet been constrained.
+ std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
+
+ const HloComputation* computation_;
+};
+
+// HLO pass which assigns layouts to all instructions in the HLO module while
+// satisfying all necessary invariants and minimizing cost.
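+//
+// Typical usage (an illustrative sketch; `module` is assumed to be an
+// existing HloModule*):
+//
+//   ComputationLayout entry_layout(
+//       module->entry_computation()->ComputeProgramShape());
+//   // ...constrain entry_layout's parameter/result layouts as needed...
+//   LayoutAssignment layout_assignment(&entry_layout);
+//   TF_ASSIGN_OR_RETURN(bool changed, layout_assignment.Run(module));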
+class LayoutAssignment : public HloPass {
+ public:
+ // entry_computation_layout is modified to populate a layout for the result in
+ // the case that no particular layout is requested.
+ explicit LayoutAssignment(ComputationLayout* entry_computation_layout);
+ ~LayoutAssignment() override {}
+
+ // Assign layouts to the given module. Returns whether the module was changed
+ // (any layouts were changed).
+ StatusOr<bool> Run(HloModule* module) override;
+
+ protected:
+ // These methods, invoked by PropagateConstraints, propagate a layout
+ // constraint to its neighbors (i.e. operands and users) in order to minimize
+ // the cost of the instructions being constrained on. New constraints are
+ // added to the given constraint set.
+ //
+ // Backends can override these methods with backend-specific propagation
+ // rules.
+ virtual Status PropagateBufferConstraint(
+ const BufferLayoutConstraint& layout_constraint,
+ LayoutConstraints* constraints);
+ virtual Status PropagateOperandConstraint(
+ const OperandLayoutConstraint& layout_constraint,
+ LayoutConstraints* constraints);
+ virtual Status PropagateResultConstraint(
+ const ResultLayoutConstraint& layout_constraint,
+ LayoutConstraints* constraints);
+
+ private:
+ // Adds constraints which must be satisfied for correctness on all
+ // backends. Called once prior to propagating constraints.
+ Status AddMandatoryConstraints(const ComputationLayout& computation_layout,
+ HloComputation* computation,
+ LayoutConstraints* constraints);
+
+ // This method can be overridden to add backend-specific constraints to the
+ // layout of the instructions of a computation. This method is called after
+ // all mandatory constraints have been added via AddMandatoryConstraints
+ // and before propagating constraints.
+ virtual Status AddBackendConstraints(LayoutConstraints* constraints) {
+ return Status::OK();
+ }
+
+ // Construct constraints and assign layouts to all instructions in the
+ // computation satisfying the given ComputationLayout. Layout constraints are
+ // added, then propagated until all LogicalBuffers in the computation are
+ // constrained.
+ Status RunOnComputation(const ComputationLayout& computation_layout,
+ HloComputation* computation);
+
+ // Assign layouts to the instructions of a computation which satisfy the given
+ // layout constraints. Copies may be added to satisfy the constraints. The
+ // given LayoutConstraints must have layout constraints for every logical buffer
+ // in the computation.
+ Status AssignLayouts(const LayoutConstraints& constraints,
+ HloComputation* computation);
+
+ // Propagates layout constraints from a set of initial constraints in order to
+ // minimize the local cost of the computation. This propagation is *not*
+ // required for correctness.
+ Status PropagateConstraints(LayoutConstraints* constraints);
+
+ // Propagates a layout constraint on the use of the result of the given
+ // instruction to the definitions of the LogicalBuffers which make up the
+ // result.
+ Status PropagateUseConstraintToDefs(const ShapeLayout& shape_layout,
+ const HloInstruction* instruction,
+ LayoutConstraints* constraints);
+
+ // Chooses a layout of operand `operand_no` of `instruction` that minimizes
+ // the cost of `instruction`. `output_layout` is the layout of `instruction`.
+ // Returns null if it can't decide the best layout.
+ // Precondition: `instruction` and the operand are array-shaped.
+ std::unique_ptr<Layout> ChooseOperandLayoutFromOutputLayout(
+ const Layout& output_layout, const HloInstruction* instruction,
+ int64 operand_no);
+ // Given the layout of `user`'s `operand_no`-th operand, chooses a layout of
+ // `user` that minimizes its cost on that operand. Returns null if it can't
+ // decide the best layout.
+ // Precondition: `user` and the operand are array-shaped.
+ std::unique_ptr<Layout> ChooseOutputLayoutFromOperandLayout(
+ const Layout& operand_layout, const HloInstruction* user,
+ int64 operand_no);
+
+ ComputationLayout* entry_computation_layout_;
+
+ // Map containing the layouts of all computations assigned so
+ // far. Computations are handled in a topological sort where computations are
+ // handled before their caller instructions so the layouts of caller
+ // instructions can be set to match the computation.
+ std::map<HloComputation*, ComputationLayout> computation_layouts_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LAYOUT_ASSIGNMENT_H_
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
new file mode 100644
index 0000000000..6361907b0e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -0,0 +1,486 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
+
+#include <initializer_list>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+namespace {
+
+class LayoutAssignmentTest : public HloTestBase {
+ protected:
+ void AssignLayouts(HloModule* module,
+ ComputationLayout* entry_computation_layout) {
+ LayoutAssignment layout_assignment(entry_computation_layout);
+ EXPECT_IS_OK(layout_assignment.Run(module).status());
+ }
+};
+
+TEST_F(LayoutAssignmentTest, ComputationLayout) {
+ // Verify the layouts of the root and parameter instructions of a computation
+ // match the ComputationLayout for two different layouts.
+ std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
+ for (auto& minor_to_major : minor_to_majors) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ashape, "param1"));
+ auto add = builder.AddInstruction(
+ HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ Layout layout = LayoutUtil::MakeLayout(minor_to_major);
+ Shape shape(ashape);
+ *shape.mutable_layout() = layout;
+ const ShapeLayout shape_layout(shape);
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) = shape_layout;
+ *computation_layout.mutable_parameter_layout(1) = shape_layout;
+ *computation_layout.mutable_result_layout() = shape_layout;
+ AssignLayouts(&module, &computation_layout);
+ EXPECT_TRUE(LayoutUtil::Equal(layout, param0->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(layout, param1->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(layout, add->shape().layout()));
+ }
+}
+
+TEST_F(LayoutAssignmentTest, ComputationLayoutMixedLayout) {
+ // Verify the layouts of the root and parameter instructions of a computation
+ // match the ComputationLayout which has mixed layout.
+ auto builder = HloComputation::Builder(TestName());
+ Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ashape, "param1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, param0, param1));
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ Layout col_major_layout = LayoutUtil::MakeLayout({1, 0});
+ Shape col_major_shape(ashape);
+ *col_major_shape.mutable_layout() = col_major_layout;
+ const ShapeLayout col_major(col_major_shape);
+
+ Layout row_major_layout = LayoutUtil::MakeLayout({0, 1});
+ Shape row_major_shape(ashape);
+ *row_major_shape.mutable_layout() = row_major_layout;
+ const ShapeLayout row_major(row_major_shape);
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) = col_major;
+ *computation_layout.mutable_parameter_layout(1) = row_major;
+ *computation_layout.mutable_result_layout() = col_major;
+
+ AssignLayouts(&module, &computation_layout);
+ EXPECT_TRUE(LayoutUtil::Equal(col_major_layout, param0->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(row_major_layout, param1->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(
+ col_major_layout, computation->root_instruction()->shape().layout()));
+}
+
+TEST_F(LayoutAssignmentTest, FusionInstruction) {
+ // Verify that the layouts of the fused parameters in a fusion instruction
+ // match those of the fusion operands. Other fused instructions should have no
+ // layout.
+ std::vector<std::initializer_list<int64>> minor_to_majors = {{0, 1}, {1, 0}};
+ for (auto& minor_to_major : minor_to_majors) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant_literal1 = test_utils::CreateR2LiteralWithLayout<float>(
+ {{1.0, 2.0}, {3.0, 4.0}}, minor_to_major);
+ auto constant_literal2 = test_utils::CreateR2LiteralWithLayout<float>(
+ {{5.0, 6.0}, {7.0, 8.0}}, minor_to_major);
+ Shape ashape = constant_literal1->shape();
+
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(constant_literal1)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(constant_literal2)));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ ashape, HloOpcode::kAdd, constant1, constant2));
+ auto negate1 = builder.AddInstruction(
+ HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, add));
+ auto negate2 = builder.AddInstruction(
+ HloInstruction::CreateUnary(ashape, HloOpcode::kNegate, negate1));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build());
+
+ auto fusion = computation->CreateFusionInstruction(
+ {negate2, negate1, add}, HloInstruction::FusionKind::kLoop);
+
+ Layout layout = LayoutUtil::MakeLayout(minor_to_major);
+ Shape shape(ashape);
+ *shape.mutable_layout() = layout;
+ const ShapeLayout shape_layout(shape);
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_result_layout() = shape_layout;
+
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_TRUE(LayoutUtil::Equal(
+ layout, fusion->fused_parameter(0)->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(
+ layout, fusion->fused_parameter(1)->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(
+ layout, fusion->fused_expression_root()->shape().layout()));
+
+ // Inner fused node should not have layout.
+ EXPECT_FALSE(LayoutUtil::HasLayout(
+ fusion->fused_expression_root()->operand(0)->shape()));
+ }
+}
+
+TEST_F(LayoutAssignmentTest, TupleLayout) {
+ // Verify the layouts of a tuple are assigned properly (the element layouts
+ // match their source).
+ auto builder = HloComputation::Builder(TestName());
+ auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ {0, 1})));
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ {1, 0})));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant0, constant1}));
+
+ // To avoid having to construct a tuple layout in the ComputationLayout below,
+ // make the result of the computation an array.
+ auto get_element0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(constant0->shape(), tuple, 0));
+ auto negate = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant0->shape(), HloOpcode::kNegate, get_element0));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ ComputationLayout computation_layout(
+ module.entry_computation()->ComputeProgramShape());
+
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_FALSE(
+ LayoutUtil::LayoutsInShapesEqual(constant0->shape(), constant1->shape()));
+
+ EXPECT_TRUE(LayoutUtil::HasLayout(tuple->shape()));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
+ negate->shape(), computation_layout.result_layout().shape()));
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(
+ ShapeUtil::GetTupleElementShape(tuple->shape(), 1), constant1->shape()));
+}
+
+TEST_F(LayoutAssignmentTest, TupleSelect) {
+ // Verify that the layouts of a select with tuple operands are assigned properly.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ {0, 1})));
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ {1, 0})));
+ auto tuple0 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant0, constant1}));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant0, constant1}));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
+
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ ComputationLayout computation_layout(
+ module.entry_computation()->ComputeProgramShape());
+ Shape result_shape =
+ ShapeUtil::MakeTupleShape({constant0->shape(), constant1->shape()});
+ TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
+ result_shape));
+
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(result_shape, select->shape()));
+}
+
+TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) {
+ // Construct the following computation, which has conflicting layouts for two
+ // elements of a tuple that share the same source logical buffer:
+ //
+ // %constant = Constant(...)
+ // %inner_tuple = Tuple(%constant)
+ // %nested_tuple = Tuple(%inner_tuple, %inner_tuple)
+ //
+ // Result layout col-major for the first element and row-major for the
+ // second. This results in the conflict where the element of the inner_tuple
+ // needs to be both col and row major. This is resolved by deep-copying the
+ // tuple and assigning the layouts of the copied arrays as needed.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ auto inner_tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({constant}));
+ auto nested_tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({inner_tuple, inner_tuple}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ ComputationLayout computation_layout(
+ module.entry_computation()->ComputeProgramShape());
+ Shape result_shape = nested_tuple->shape();
+ *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{0, 0}) =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+ *ShapeUtil::GetMutableSubshape(&result_shape, /*index=*/{1, 0}) =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1});
+ TF_CHECK_OK(computation_layout.mutable_result_layout()->CopyLayoutFromShape(
+ result_shape));
+
+ LayoutAssignment layout_assignment(&computation_layout);
+ AssignLayouts(&module, &computation_layout);
+
+ // Layout assignment should have deep copied the result of the computation to
+ // address the layout conflict. This results in several Tuple() and
+ // GetTupleElement() instructions. Running algebraic simplification should
+ // clean up the code to something like:
+ //
+ // %constant = Constant(...) layout={1,0}
+ // %tuple.0 = Tuple(%constant) layout=({1,0})
+ // %copy = Copy(%constant) layout={0,1} # layout transposed
+ // %tuple.1 = Tuple(%copy) layout=({0,1})
+ // %tuple.2 = Tuple(%tuple.0, %tuple.1) layout=(({1,0}), ({0,1}))
+ //
+ EXPECT_TRUE(
+ AlgebraicSimplifier(/*is_layout_sensitive=*/true,
+ [](const Shape&, const Shape&) { return false; })
+ .Run(&module)
+ .ValueOrDie());
+ HloInstruction* root = module.entry_computation()->root_instruction();
+ // Verify layout of the root and the root's operands.
+ EXPECT_TRUE(ShapeUtil::Equal(result_shape, root->shape()));
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {0}),
+ root->operand(0)->shape()));
+ EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::GetSubshape(result_shape, {1}),
+ root->operand(1)->shape()));
+
+ // Verify some of the structure of the HLO graph.
+ EXPECT_EQ(constant, root->operand(0)->operand(0));
+ EXPECT_EQ(HloOpcode::kCopy, root->operand(1)->operand(0)->opcode());
+ EXPECT_EQ(HloOpcode::kConstant,
+ root->operand(1)->operand(0)->operand(0)->opcode());
+}
+
+TEST_F(LayoutAssignmentTest, ElementwiseAndReshape) {
+ // param -> log -> reshape -> tanh
+ auto builder = HloComputation::Builder(TestName());
+ Shape ashape = ShapeUtil::MakeShape(F32, {1, 2, 3, 1});
+ Shape bshape = ShapeUtil::MakeShape(F32, {2, 1, 3});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "param"));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
+ auto reshape =
+ builder.AddInstruction(HloInstruction::CreateReshape(bshape, log));
+ auto tanh = builder.AddInstruction(
+ HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, reshape));
+
+ HloModule module(TestName());
+ HloComputation* computation = module.AddEntryComputation(builder.Build(tanh));
+
+ Shape ashape_with_layout(ashape);
+ Shape bshape_with_layout(bshape);
+ *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3});
+ *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2});
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ashape_with_layout);
+ *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
+ AssignLayouts(&module, &computation_layout);
+
+ auto log_minor_to_major =
+ AsInt64Slice(log->shape().layout().minor_to_major());
+ EXPECT_LT(PositionInContainer(log_minor_to_major, 1),
+ PositionInContainer(log_minor_to_major, 2));
+
+ auto reshape_minor_to_major =
+ AsInt64Slice(reshape->shape().layout().minor_to_major());
+ EXPECT_LT(PositionInContainer(reshape_minor_to_major, 0),
+ PositionInContainer(reshape_minor_to_major, 2));
+}
+
+// Test whether LayoutAssignment assigns layouts to elementwise operations to
+// keep linear indices valid across them, and to transpositions to make them
+// bitcasts.
+TEST_F(LayoutAssignmentTest, ElementwiseAndTranspose) {
+ // param -> log -> transpose -> tanh
+ auto builder = HloComputation::Builder(TestName());
+ Shape ashape = ShapeUtil::MakeShape(F32, {42, 12});
+ Shape bshape = ShapeUtil::MakeShape(F32, {12, 42});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "param"));
+ auto log = builder.AddInstruction(
+ HloInstruction::CreateUnary(ashape, HloOpcode::kLog, param));
+ auto transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(bshape, log, {1, 0}));
+ auto tanh = builder.AddInstruction(
+ HloInstruction::CreateUnary(bshape, HloOpcode::kTanh, transpose));
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build(tanh));
+
+ Shape ashape_with_layout(ashape);
+ Shape bshape_with_layout(bshape);
+ *ashape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ *bshape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ashape_with_layout);
+ *computation_layout.mutable_result_layout() = ShapeLayout(bshape_with_layout);
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_TRUE(
+ LayoutUtil::Equal(ashape_with_layout.layout(), log->shape().layout()));
+ EXPECT_TRUE(LayoutUtil::Equal(bshape_with_layout.layout(),
+ transpose->shape().layout()));
+ EXPECT_TRUE(
+ LayoutUtil::Equal(bshape_with_layout.layout(), tanh->shape().layout()));
+}
+
+// Test whether LayoutAssignment assigns layouts to transpositions to make them
+// bitcasts.
+TEST_F(LayoutAssignmentTest, BroadcastAndTranspose) {
+ // param -> broadcast -> transpose
+ auto builder = HloComputation::Builder(TestName());
+ Shape ashape = ShapeUtil::MakeShape(F32, {3, 4});
+ Shape bshape = ShapeUtil::MakeShape(F32, {2, 3, 4});
+ Shape cshape = ShapeUtil::MakeShape(F32, {4, 3, 2});
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ashape, "param"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(bshape, param, {1, 2}));
+ auto transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(cshape, broadcast, {2, 1, 0}));
+ HloModule module(TestName());
+ HloComputation* computation =
+ module.AddEntryComputation(builder.Build(transpose));
+
+ Shape input_shape_with_layout(ashape);
+ Shape output_shape_with_layout(cshape);
+ *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ *output_shape_with_layout.mutable_layout() =
+ LayoutUtil::MakeLayout({2, 1, 0});
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(input_shape_with_layout);
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(output_shape_with_layout);
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(),
+ tensorflow::gtl::ArraySlice<int64>{0, 1, 2}));
+}
+
+TEST_F(LayoutAssignmentTest, ReshapeOperandHasMultipleUsers) {
+ // param[4] -> broadcast[3x4] ------> transpose[4x3]-------- -------> tuple
+ // \ /
+ // \-> tanh[3x4] -> broadcast2[2x3x4] -/
+ //
+ // The layout of `transpose` is set to {1,0} because it provides a buffer to
+  // the computation result, which has a fixed layout. Therefore, `broadcast`
+ // (the operand of transpose) is expected to have layout {0,1} so that the
+ // transpose is a bitcast. Furthermore, `tanh` is expected to have the same
+ // layout as `broadcast` (i.e. {0,1}) because `tanh` is elementwise.
+ Shape f32_4 = ShapeUtil::MakeShape(F32, {4});
+ Shape f32_34 = ShapeUtil::MakeShape(F32, {3, 4});
+ Shape f32_43 = ShapeUtil::MakeShape(F32, {4, 3});
+ Shape f32_234 = ShapeUtil::MakeShape(F32, {2, 3, 4});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32_4, "param"));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32_34, param, {3}));
+ auto transpose = builder.AddInstruction(
+ HloInstruction::CreateTranspose(f32_43, broadcast, {1, 0}));
+ auto tanh = builder.AddInstruction(
+ HloInstruction::CreateUnary(f32_34, HloOpcode::kTanh, broadcast));
+ auto broadcast2 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32_234, tanh, {2}));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({transpose, broadcast2}));
+ HloModule module(TestName());
+ HloComputation* computation =
+ module.AddEntryComputation(builder.Build(tuple));
+
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ Shape param_shape_with_layout(f32_4);
+ Shape transpose_shape_with_layout(f32_43);
+ Shape broadcast2_shape_with_layout(f32_234);
+ *param_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({0});
+ *transpose_shape_with_layout.mutable_layout() =
+ LayoutUtil::MakeLayout({1, 0});
+ *broadcast2_shape_with_layout.mutable_layout() =
+ LayoutUtil::MakeLayout({2, 1, 0});
+
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(param_shape_with_layout);
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeTupleShape(
+ {transpose_shape_with_layout, broadcast2_shape_with_layout}));
+ AssignLayouts(&module, &computation_layout);
+
+ EXPECT_TRUE(ContainersEqual(broadcast->shape().layout().minor_to_major(),
+ tensorflow::gtl::ArraySlice<int64>{0, 1}));
+ EXPECT_TRUE(ContainersEqual(transpose->shape().layout().minor_to_major(),
+ tensorflow::gtl::ArraySlice<int64>{1, 0}));
+ EXPECT_TRUE(ContainersEqual(tanh->shape().layout().minor_to_major(),
+ tensorflow::gtl::ArraySlice<int64>{0, 1}));
+}
+
+// TODO: Add a test which fails due to tuple copying.
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
new file mode 100644
index 0000000000..10468d9aae
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -0,0 +1,154 @@
+# Description:
+# Libraries for helping construct LLVM IR for XLA backends.
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = [":friends"])
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+cc_library(
+ name = "alias_analysis",
+ srcs = ["alias_analysis.cc"],
+ hdrs = ["alias_analysis.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:logical_buffer",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "llvm_util",
+ srcs = ["llvm_util.cc"],
+ hdrs = ["llvm_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags",
+ "//tensorflow/compiler/xla/service:buffer_assignment",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ "@llvm//:support",
+ "@llvm//:target",
+ ],
+)
+
+cc_library(
+ name = "ir_array",
+ srcs = ["ir_array.cc"],
+ hdrs = ["ir_array.h"],
+ deps = [
+ ":llvm_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "llvm_loop",
+ srcs = ["llvm_loop.cc"],
+ hdrs = ["llvm_loop.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "loop_emitter",
+ srcs = ["loop_emitter.cc"],
+ hdrs = ["loop_emitter.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_loop",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "fused_ir_emitter",
+ srcs = ["fused_ir_emitter.cc"],
+ hdrs = ["fused_ir_emitter.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ ":loop_emitter",
+ ":ops",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service:elemental_ir_emitter",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "ops",
+ srcs = ["ops.cc"],
+ hdrs = ["ops.h"],
+ deps = [
+ ":ir_array",
+ ":llvm_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/service/llvm_ir/README.md b/tensorflow/compiler/xla/service/llvm_ir/README.md
new file mode 100644
index 0000000000..9fe7152477
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/README.md
@@ -0,0 +1,2 @@
+Common utilities and abstractions for handling and emitting LLVM IR for XLA
+backends.
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
new file mode 100644
index 0000000000..a552ea0218
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
+
+#include <unordered_set>
+
+#include "external/llvm/include/llvm/IR/MDBuilder.h"
+#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo,
+ llvm_ir::IrArray* array) {
+ BufferAllocation::Index buffer_index;
+ if (hlo.opcode() == HloOpcode::kParameter) {
+ // Parameters may alias with each other but may not alias with our temporary
+ // buffers.
+ buffer_index = kParameterAliasSet;
+ } else {
+ const std::set<BufferAllocation> allocations =
+ assignment_.GetAllocations(&hlo, /*index=*/{});
+ if (allocations.empty() || allocations.size() > 1) {
+      // Skip HLOs which don't have a buffer assigned or for which the
+      // buffer can't be determined statically. We cannot determine their
+      // aliasing properties in these cases.
+ return;
+ }
+ buffer_index = allocations.begin()->index();
+ }
+
+ llvm::MDNode*& alias_scope_md = alias_scope_metadata_[buffer_index];
+ if (alias_scope_md == nullptr) {
+ alias_scope_md =
+ GetAliasScopeMetadataForBuffer(buffer_index, GetAliasDomain());
+ }
+ array->AddAliasScopeMetadata(alias_scope_md);
+
+ llvm::MDNode*& noalias_md = noalias_metadata_[buffer_index];
+ if (noalias_md == nullptr) {
+ noalias_md = GetNoaliasMetadataForBuffer(buffer_index, GetAliasDomain(),
+ assignment_, hlo);
+ }
+ array->AddNoaliasMetadata(noalias_md);
+
+  // Parameters of the entry computation are never stored to, so loading from a
+  // parameter pointer should always return the same result within a loop.
+ if (hlo.opcode() == HloOpcode::kParameter) {
+ const std::vector<HloInstruction*>& parameter_instructions =
+ module_.entry_computation()->parameter_instructions();
+ if (std::find(parameter_instructions.begin(), parameter_instructions.end(),
+ &hlo) != parameter_instructions.end()) {
+ array->AddInvariantLoad(llvm::MDNode::get(*context_, /*MDs=*/{}));
+ }
+ }
+}
+
+llvm::MDNode* AliasAnalysis::GetAliasDomain() {
+ llvm::MDBuilder metadata_builder(*context_);
+ if (alias_domain_ == nullptr) {
+ alias_domain_ = metadata_builder.createAnonymousAliasScopeDomain();
+ }
+ return alias_domain_;
+}
+
+llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer(
+ BufferAllocation::Index buffer_index, llvm::MDNode* domain) {
+ legacy_flags::AliasAnalysisFlags* flags =
+ legacy_flags::GetAliasAnalysisFlags();
+ if (!flags->xla_emit_alias_scope) {
+ return nullptr;
+ }
+
+ // While we could synthesize an alias.scope, doing so is not more profitable
+ // than LLVM's default behavior.
+ if (buffer_index == kParameterAliasSet) {
+ return nullptr;
+ }
+
+ llvm::MDBuilder metadata_builder(domain->getContext());
+ llvm::MDNode* scope = metadata_builder.createAliasScope(
+ AsStringRef(tensorflow::strings::StrCat("buffer: ", buffer_index)),
+ domain);
+ llvm::MDNode* scope_list = llvm::MDNode::get(domain->getContext(), scope);
+ return scope_list;
+}
+
+llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
+ BufferAllocation::Index buffer_index, llvm::MDNode* domain,
+ const BufferAssignment& assignment, const HloInstruction& hlo) {
+ legacy_flags::AliasAnalysisFlags* flags =
+ legacy_flags::GetAliasAnalysisFlags();
+ if (!flags->xla_emit_alias_scope) {
+ return nullptr;
+ }
+
+ // We want to construct a list of buffers which:
+ //
+ // 1. Do not alias the given buffer.
+ // 2. Will plausibly be used in the vicinity of the given buffer.
+ //
+  // Making the noalias set overly large will result either in a massive
+  // slowdown in LLVM or in LLVM simply ignoring the noalias set.
+ //
+  // A plausible list of instructions is:
+ // 1. Users of the given hlo.
+ // 2. Operands of users of the given hlo.
+ // 3. Operands of the given hlo.
+ //
+  // This set can be grown as needed. For now, only consider top-level
+  // buffers (index = {}), not buffers nested within the instruction's
+  // operands/output, which are not typically touched.
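+  //
+  // As an illustrative sketch (instruction names here are hypothetical): for
+  //   %c = add(%a, %b)  and  %d = multiply(%c, %e),
+  // the noalias candidates for %c's buffer would be drawn from %d (a user),
+  // %e (an operand of a user), and %a and %b (operands of %c), minus any
+  // buffer that is unassigned, a parameter, or the buffer itself.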
+ std::vector<const LogicalBuffer*> worklist;
+ auto add_buffers_to_worklist =
+ [&worklist, &assignment](const HloInstruction* instruction) {
+ for (const LogicalBuffer* buffer :
+ assignment.GetSourceBuffers(instruction, /*index=*/{})) {
+ worklist.push_back(buffer);
+ }
+ };
+
+ for (HloInstruction* user : hlo.users()) {
+ add_buffers_to_worklist(user);
+ for (HloInstruction* operand : user->operands()) {
+ add_buffers_to_worklist(operand);
+ }
+ }
+
+ add_buffers_to_worklist(&hlo);
+ for (HloInstruction* operand : hlo.operands()) {
+ add_buffers_to_worklist(operand);
+ }
+
+ std::unordered_set<BufferAllocation::Index> buffers;
+ for (const LogicalBuffer* buffer : worklist) {
+ // Skip buffers which cannot be added to the noalias set.
+ if (!assignment.HasAllocation(*buffer) ||
+ buffer->instruction()->opcode() == HloOpcode::kParameter) {
+ continue;
+ }
+ BufferAllocation::Index noalias_index =
+ assignment.GetAssignedAllocation(*buffer).index();
+    // A buffer must not appear in its own noalias set.
+ if (noalias_index != buffer_index) {
+ buffers.insert(noalias_index);
+ // Some instructions have too many operands, causing the noalias set to be
+ // too large. To reduce compilation time (b/31901575), truncate noalias
+ // sets to at most 500 elements.
+ //
+ // Future work: improvements to LLVM's scoped AA that avoid creating a
+ // MDNode set for every alias query can help to reduce the compilation
+ // time as well.
+ constexpr int kMaxNoAliasSetSize = 500;
+ if (buffers.size() >= kMaxNoAliasSetSize) {
+ break;
+ }
+ }
+ }
+
+ // Don't bother constructing a noalias metadata node if it would be empty.
+ if (buffers.empty()) {
+ return nullptr;
+ }
+
+ llvm::MDBuilder metadata_builder(domain->getContext());
+ std::vector<llvm::Metadata*> scopes;
+ for (BufferAllocation::Index noalias_index : buffers) {
+ llvm::MDNode* scope = metadata_builder.createAliasScope(
+ AsStringRef(tensorflow::strings::StrCat("buffer: ", noalias_index)),
+ domain);
+ scopes.push_back(scope);
+ }
+ llvm::MDNode* noalias_list =
+ llvm::MDNode::get(domain->getContext(), AsArrayRef(scopes));
+ return noalias_list;
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
new file mode 100644
index 0000000000..d8d45dd49b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+
+#include <unordered_map>
+
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// Helper functionality used to augment the emitted LLVM IR with alias-scope
+// metadata.
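+//
+// Illustrative usage sketch (variable names here are hypothetical; the real
+// call sites live in the backend IR emitters):
+//
+//   AliasAnalysis alias_analysis(hlo_module, buffer_assignment, &llvm_context);
+//   alias_analysis.AddAliasingInformationToIrArray(*hlo, &ir_array);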
+class AliasAnalysis {
+ public:
+ AliasAnalysis(const HloModule& module, const BufferAssignment& assignment,
+ llvm::LLVMContext* context)
+ : module_(module), assignment_(assignment), context_(context) {}
+
+ // Augments IrArray with aliasing information.
+ void AddAliasingInformationToIrArray(const HloInstruction& hlo,
+ llvm_ir::IrArray* array);
+
+ private:
+ // Returns a unique alias domain for this emitter.
+ llvm::MDNode* GetAliasDomain();
+
+ // Returns an alias.scope metadata node corresponding to a given buffer index.
+ llvm::MDNode* GetAliasScopeMetadataForBuffer(
+ BufferAllocation::Index buffer_index, llvm::MDNode* domain);
+
+ // Returns a noalias metadata node corresponding to a given buffer index.
+ //
+ // |buffer_index| is the buffer index.
+ //
+ // |domain| corresponds to the alias scope domain as documented at
+ // http://llvm.org/docs/LangRef.html#noalias-and-alias-scope-metadata
+ //
+ // |hlo| is the instruction we are computing a noalias set for.
+ llvm::MDNode* GetNoaliasMetadataForBuffer(
+ BufferAllocation::Index buffer_index, llvm::MDNode* domain,
+ const BufferAssignment& assignment, const HloInstruction& hlo);
+
+ // The HLO module we are compiling for.
+ const HloModule& module_;
+
+ // Assignment of the temporary buffers needed by the computation and their
+ // shape information.
+ const BufferAssignment& assignment_;
+
+ // The LLVM context which we are using for IR emission.
+ llvm::LLVMContext* context_;
+
+ // Holds the alias domain for this computation.
+ llvm::MDNode* alias_domain_ = nullptr;
+
+ // Index in alias_scope_metadata_ and noalias_metadata_ for parameters
+ // of the entry computation which have special aliasing properties.
+ static constexpr int kParameterAliasSet = -1;
+
+ // A map from a buffer index to metadata corresponding to its alias.scope
+ // metadata. The index kParameterAliasSet is used to hold aliasing
+ // information for parameters.
+ std::unordered_map<int, llvm::MDNode*> alias_scope_metadata_;
+
+ // A map from a buffer index to metadata corresponding to its noalias
+ // metadata.
+ std::unordered_map<int, llvm::MDNode*> noalias_metadata_;
+};
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
new file mode 100644
index 0000000000..b259d34870
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc
@@ -0,0 +1,147 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
+
+#include <functional>
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+using llvm_ir::IrArray;
+
+Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
+ generators_[hlo] =
+ [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ if (generated_value_cache_[hlo].count(index.multidim()) > 0) {
+ llvm::Value* generated_value =
+ generated_value_cache_[hlo][index.multidim()];
+ llvm::BasicBlock* generated_value_bb = nullptr;
+ if (auto* generated_instruction =
+ llvm::dyn_cast<llvm::Instruction>(generated_value)) {
+ generated_value_bb = generated_instruction->getParent();
+ }
+ // Ideally, we should be able to reuse the cached generated value if it
+ // dominates the current insertion block. However, the check for dominance
+ // can be expensive and unreliable when the function is being constructed.
+ //
+      // It's also worth experimenting with not doing caching at all.
+ // LLVM's CSE or GVN should be able to easily merge common subexpressions
+ // that would be regenerated without caching. But this might increase the
+ // JIT compilation time.
+ if (generated_value_bb == nullptr ||
+ generated_value_bb == ir_builder_->GetInsertBlock()) {
+ VLOG(3) << "The cached generated value is reused.";
+ return generated_value;
+ }
+ VLOG(3) << "The cached generated value can't be reuse, because it is at "
+ "a different BB ("
+ << llvm_ir::AsString(generated_value_bb->getName())
+ << ") from the current insertion block ("
+ << llvm_ir::AsString(ir_builder_->GetInsertBlock()->getName())
+ << ").";
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ generated_value_cache_[hlo][index.multidim()],
+ elemental_emitter_->MakeElementGenerator(hlo, generators_)(index));
+ return generated_value_cache_[hlo][index.multidim()];
+ };
+ return Status::OK();
+}
+
+Status FusedIrEmitter::HandleConstant(HloInstruction* constant,
+ const Literal& literal) {
+ llvm::Constant* initializer =
+ llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_);
+ llvm::GlobalVariable* global = new llvm::GlobalVariable(
+ *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(),
+ /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer,
+ /*Name=*/"");
+ generators_[constant] = [=](const IrArray::Index& index) {
+ return IrArray(global, constant->shape())
+ .EmitReadArrayElement(index, ir_builder_);
+ };
+
+ return Status::OK();
+}
+
+Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) {
+ // Lookup ir value for 'operand'.
+ auto it = gte_values_.find(operand);
+ if (it == gte_values_.end()) {
+ return Unimplemented(
+ "GetTupleElement fusion currently only supports"
+ " parameter operands, but found operand: %s",
+ operand->name().c_str());
+ }
+ // Emit code to lookup tuple element pointer, and store it in 'gte_values_'.
+ llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement(
+ get_tuple_element->shape(), get_tuple_element->tuple_index(),
+ /*alignment=*/1, it->second, ir_builder_);
+ gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr));
+ // Emit code to read base tuple element array (if non-tuple shaped).
+ if (!ShapeUtil::IsTuple(get_tuple_element->shape())) {
+ generators_[get_tuple_element] =
+ [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
+ // TODO(b/34080002) Add aliasing information to tuple element IrArray.
+ return IrArray(tuple_element_ptr, get_tuple_element->shape())
+ .EmitReadArrayElement(index, ir_builder_);
+ };
+ }
+ return Status::OK();
+}
+
+Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
+ generators_[parameter] = [=](const IrArray::Index& index) {
+ return parameter_arrays_[parameter->parameter_number()]
+ .EmitReadArrayElement(index, ir_builder_);
+ };
+ // Store ir value for fusion operand associated with fusion parameter to be
+ // accessed by subsequent fused GetTupleElement instructions.
+ gte_values_.insert(std::make_pair(
+ parameter,
+ parameter_arrays_[parameter->parameter_number()].GetBasePointer()));
+ return Status::OK();
+}
+
+Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
+ fused_root_ = root;
+ return tensorflow::Status::OK();
+}
+
+FusedIrEmitter::Generator FusedIrEmitter::GetRootGenerator() const {
+ CHECK_NE(nullptr, fused_root_)
+ << "GetRootGenerator should be called after Accept.";
+ return generators_.at(fused_root_);
+}
+
+FusedIrEmitter::Generator FusedIrEmitter::GetGenerator(
+ const HloInstruction* instruction) const {
+ return generators_.at(instruction);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
new file mode 100644
index 0000000000..303bb3ee6b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
+
+#include <map>
+#include <unordered_map>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace xla {
+
+// Unlike IrEmitter, this creates host functions which emit IR to generate the
+// output element at the given index. It is used to generate fused operations.
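+//
+// Illustrative usage sketch (variable names are hypothetical): construct the
+// emitter, walk the fused computation with the visitor, then fetch generators.
+//
+//   FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter);
+//   TF_RETURN_IF_ERROR(fused_root->Accept(&fused_emitter));
+//   auto root_generator = fused_emitter.GetRootGenerator();
+//   TF_ASSIGN_OR_RETURN(llvm::Value * result, root_generator(index));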
+class FusedIrEmitter : public DfsHloVisitorWithDefault {
+ public:
+ using Generator = llvm_ir::ElementGenerator;
+
+ FusedIrEmitter(tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays,
+ ElementalIrEmitter* elemental_emitter)
+ : parameter_arrays_(parameter_arrays),
+ elemental_emitter_(elemental_emitter),
+ ir_builder_(elemental_emitter->ir_builder()) {}
+
+ Status DefaultAction(HloInstruction* hlo) override;
+
+ Status HandleConstant(HloInstruction* constant,
+ const Literal& literal) override;
+
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+
+ Status HandleParameter(HloInstruction* parameter) override;
+
+ Status FinishVisit(HloInstruction* root) override;
+
+ // Returns the generator function for the root of the fused computation.
+ Generator GetRootGenerator() const;
+
+ // Returns the generator function for the given instruction.
+ Generator GetGenerator(const HloInstruction* instruction) const;
+
+ private:
+  // Arrays of the parameters of the fusion instruction.
+ tensorflow::gtl::ArraySlice<llvm_ir::IrArray> parameter_arrays_;
+
+ ElementalIrEmitter* elemental_emitter_;
+
+ // This member will be set by FinishVisit and used in GetRootGenerator.
+ const HloInstruction* fused_root_ = nullptr;
+
+  // Borrowed; not owned. Obtained from elemental_emitter_->ir_builder().
+ llvm::IRBuilder<>* ir_builder_;
+
+ // Map from instruction pointers to functions to generate elements of their
+ // outputs
+ std::unordered_map<const HloInstruction*, Generator> generators_;
+
+ // Cache of generated values, lest we regenerate an element of a node with
+ // multiple outgoing edges
+ std::unordered_map<const HloInstruction*,
+ std::map<std::vector<llvm::Value*>, llvm::Value*>>
+ generated_value_cache_;
+
+ // Stores ir values required to emit fused (and possibly nested)
+ // GetTupleElement instructions.
+ std::unordered_map<const HloInstruction*, llvm::Value*> gte_values_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_FUSED_IR_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
new file mode 100644
index 0000000000..c095dea7a8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc
@@ -0,0 +1,274 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+
+#include "external/llvm/include/llvm/IR/Constants.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
+ llvm::IRBuilder<>* ir_builder)
+ : multidim_(ShapeUtil::Rank(shape)),
+ linear_(linear),
+ layout_(shape.layout()),
+ dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ CHECK(LayoutUtil::HasLayout(shape))
+ << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
+ << " should have a layout.";
+ int64 divisor = 1;
+ for (int64 dimension : layout_.minor_to_major()) {
+ int64 size_of_current_dimension = shape.dimensions(dimension);
+ // Emit IR instructions that compute
+ // (linear_index / divisor) % current_dimension
+ multidim_[dimension] = ir_builder->CreateURem(
+ ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)),
+ ir_builder->getInt64(size_of_current_dimension));
+ divisor *= size_of_current_dimension;
+ }
+}
+
+IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ llvm::Value* linear, const Shape& shape)
+ : multidim_(multidim.begin(), multidim.end()),
+ linear_(linear),
+ layout_(shape.layout()),
+ dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ CHECK_EQ(shape.dimensions_size(), multidim.size());
+ CHECK(LayoutUtil::HasLayout(shape))
+ << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
+ << " should have a layout.";
+}
+
+IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ const Shape& shape, llvm::IRBuilder<>* ir_builder)
+ : multidim_(multidim.begin(), multidim.end()),
+ layout_(shape.layout()),
+ dims_(shape.dimensions().begin(), shape.dimensions().end()) {
+ CHECK_EQ(shape.dimensions_size(), multidim.size());
+ CHECK(LayoutUtil::HasLayout(shape));
+ linear_ = Linearize(shape, ir_builder);
+}
+
+IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
+ : base_ptr_(base_ptr), shape_(&shape) {
+ TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
+ CHECK(base_ptr_->getType()->isPointerTy());
+ int depth = 0;
+ element_type_ =
+ llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
+ while (llvm::ArrayType* array_type =
+ llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
+ element_type_ = array_type->getElementType();
+ ++depth;
+ }
+
+ if (ShapeUtil::Rank(*shape_) == 0) {
+ DCHECK(depth == 1 || depth == 0) << depth;
+ } else {
+ DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString();
+ }
+}
+
+// Returns whether the given linear index is valid on the given shape.
+bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
+ auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_);
+ *b.mutable_layout() = layout_;
+ return linear_ != nullptr &&
+ ContainersEqual(
+ ShapeUtil::StripDegenerateDimensions(a).dimensions(),
+ ShapeUtil::StripDegenerateDimensions(b).dimensions()) &&
+ LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(),
+ ShapeUtil::StripDegenerateDimensions(b).layout());
+}
+
+IrArray::Index IrArray::Index::SourceIndexOfReshape(
+ const Shape& output_shape, const Shape& input_shape,
+ llvm::IRBuilder<>* builder) const {
+ const auto& target_index = *this;
+ CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
+ llvm::Value* logical_linear_index = Linearize(output_shape, builder);
+  // Delinearizes logical_linear_index for the source array in row-major
+  // collapsed order: each source index is the remainder of the running linear
+  // index by the corresponding dimension size, after which the linear index is
+  // divided by that size.
+ std::vector<std::pair<int64, int64>> unmodified_dims =
+ ShapeUtil::DimensionsUnmodifiedByReshape(input_shape, output_shape);
+ std::vector<llvm::Value*> source_multidim_index(ShapeUtil::Rank(input_shape));
+ for (int64 i = ShapeUtil::Rank(input_shape) - 1; i >= 0; --i) {
+ auto divisor = builder->getInt64(input_shape.dimensions(i));
+ if (input_shape.dimensions(i) <= 1) {
+ source_multidim_index[i] = builder->getInt64(0);
+ } else {
+ // Search unmodified_dims for a pair whose first element is exactly "i".
+ //
+ // Because unmodified_dims are sorted by both "first" and "second", and
+ // "i" is monotonically decreasing, we avoid redundant searching by
+ // popping the back of unmodified_dims until the rear pair's first element
+ // <= i. If we stop precisely at "i", we find a match.
+ while (!unmodified_dims.empty() && unmodified_dims.back().first > i) {
+ unmodified_dims.pop_back();
+ }
+ if (!unmodified_dims.empty() && unmodified_dims.back().first == i) {
+ source_multidim_index[i] = target_index[unmodified_dims.back().second];
+ } else {
+ source_multidim_index[i] =
+ builder->CreateURem(logical_linear_index, divisor);
+ }
+ }
+ logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
+ }
+
+ if (linear() != nullptr &&
+ ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
+ return Index(source_multidim_index, linear(), input_shape);
+ }
+ return Index(source_multidim_index);
+}
+
+IrArray::Index IrArray::Index::SourceIndexOfTranspose(
+ const Shape& shape, const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const {
+ std::vector<llvm::Value*> operand_multidim_index =
+ Permute(dimension_mapping, multidim());
+ if (linear() != nullptr &&
+ ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
+ return Index(operand_multidim_index, linear(), operand_shape);
+ }
+ return Index(operand_multidim_index);
+}
+
+llvm::Value* IrArray::Index::Linearize(const Shape& shape,
+ llvm::IRBuilder<>* builder) const {
+  // Each dimension is multiplied by the product of the sizes of all more-minor
+  // (subsequent) dimensions and added to the accumulator
+  // logical_linear_index.
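+  //
+  // For example (illustrative): for a shape with dimensions {2, 3, 4} and a
+  // multidimensional index {i, j, k}, this emits IR computing i*12 + j*4 + k.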
+ llvm::Value* logical_linear_index = builder->getInt64(0);
+ int64 multiplier = 1;
+ for (ssize_t i = size() - 1; i >= 0; --i) {
+ llvm::Value* addend =
+ builder->CreateMul((*this)[i], builder->getInt64(multiplier), "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ multiplier *= shape.dimensions(i);
+ }
+ return logical_linear_index;
+}
+
+llvm::Value* IrArray::EmitArrayElementAddress(
+ const IrArray::Index& index, llvm::IRBuilder<>* ir_builder,
+ tensorflow::StringPiece name) const {
+ if (ShapeUtil::IsScalar(*shape_)) {
+ // Special handling of scalars: a scalar pretends to have the same value for
+ // every index, thus effectively implementing broadcasting of its value
+ // over higher-rank arrays.
+ return base_ptr_;
+ }
+ CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));
+
+ std::vector<llvm::Value*> actual_index;
+ bool is_implicit_broadcast = false;
+ // We perform broadcasting when the operand shape has dimension(s) of size
+ // 1. In this case we fix the index value for that dimension to zero. This
+ // effectively broadcasts along this dimension.
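+  //
+  // For example (illustrative): indexing a [1, 4] array with {i, j} actually
+  // reads element {0, j}.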
+ for (int64 i = 0; i < index.size(); ++i) {
+ auto dim = shape_->dimensions(i);
+ actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
+ is_implicit_broadcast |= dim == 1;
+ }
+
+ if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
+ return ir_builder->CreateInBoundsGEP(
+ ir_builder->CreateBitCast(
+ base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder)
+ ->getPointerTo()),
+ {index.linear()}, llvm_ir::AsStringRef(name));
+ }
+
+ // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
+ // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
+ // should be computed by
+ //
+ // getelementptr base_ptr_, 0, most major index, ..., most minor index
+ std::vector<llvm::Value*> gep_indices(1, ir_builder->getInt64(0));
+ for (int64 i = shape_->layout().minor_to_major_size() - 1; i >= 0; --i) {
+ int64 dimension = shape_->layout().minor_to_major(i);
+ gep_indices.push_back(actual_index[dimension]);
+ }
+ return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices,
+ llvm_ir::AsStringRef(name));
+}
+
+llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
+ llvm::IRBuilder<>* ir_builder,
+ tensorflow::StringPiece name) const {
+ llvm::Value* element_address =
+ EmitArrayElementAddress(index, ir_builder, name);
+ llvm::LoadInst* load = ir_builder->CreateLoad(element_address);
+ llvm_ir::SetTbaaForInstruction(load, GetShape(),
+ /*is_pointer_to=*/false);
+ for (const std::pair<int, llvm::MDNode*>& kind_md_pair : metadata_) {
+ int kind = kind_md_pair.first;
+ llvm::MDNode* md = kind_md_pair.second;
+ load->setMetadata(kind, md);
+ }
+ return load;
+}
+
+void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
+ llvm::IRBuilder<>* ir_builder) const {
+ llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder);
+ llvm::StoreInst* store = ir_builder->CreateStore(value, element_address);
+ llvm_ir::SetTbaaForInstruction(store, GetShape(),
+ /*is_pointer_to=*/false);
+ for (const std::pair<int, llvm::MDNode*>& kind_md_pair : metadata_) {
+ int kind = kind_md_pair.first;
+ CHECK_NE(kind, llvm::LLVMContext::MD_invariant_load);
+ llvm::MDNode* md = kind_md_pair.second;
+ store->setMetadata(kind, md);
+ }
+}
+
+IrArray IrArray::CastToShape(const Shape& new_shape,
+ llvm::IRBuilder<>* ir_builder) const {
+ llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder);
+ return IrArray(
+ ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
+ new_shape);
+}
+
+/* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
+ int64 which_dimension,
+ int64 addend,
+ llvm::IRBuilder<>* ir_builder) {
+ Index new_index = index;
+ new_index[which_dimension] = ir_builder->CreateAdd(
+ index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true,
+ /*HasNSW=*/true);
+ return new_index;
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.h b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
new file mode 100644
index 0000000000..0b182267c3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.h
@@ -0,0 +1,248 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
+
+#include <map>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// IrArray represents an XLA array at the LLVM IR level. This class
+// encapsulates a base pointer to the buffer holding the array (as an LLVM
+// Value) and the shape of the array. The class includes methods for emitting
+// LLVM IR sequences which access elements of the array at a multidimensional
+// index (eg, [x, y, z] in a 3-dimensional array). Arbitrary shapes and layouts
+// are supported.
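+//
+// Illustrative usage sketch (variable names are hypothetical):
+//
+//   IrArray array(base_ptr, shape);
+//   IrArray::Index index({i, j}, shape, &ir_builder);
+//   llvm::Value* element = array.EmitReadArrayElement(index, &ir_builder);
+//   array.EmitWriteArrayElement(index, new_value, &ir_builder);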
+class IrArray {
+ public:
+ // A multidimensional index into an IrArray. The index for dimension zero is
+ // first in the vector. This is the reverse order of the notation used for
+ // describing the dimensions of an array. That is, for a [4 x 3 x 2] array
+ // dimension zero has size 2, dimension one has size 3, and dimension two has
+ // size 4. Thus the index {1, 2, 3} indexes the last element of this [4 x 3 x
+ // 2] array.
+ //
+ // This may also keep a linear index and the layout and dimensions it was
+ // emitted for; if the shape where this `Index` is used matches, the linear
+ // index may be used, potentially sparing the cost of computing the
+ // multidimensional index, which LLVM DCE can delete.
+ class Index {
+ public:
+ // Constructs an empty zero-dimensional index.
+ Index() {}
+
+ // Constructs an index of rank "size". Each dimension of the index is
+ // initialized to "value".
+ explicit Index(size_t size, llvm::Value* value = nullptr)
+ : multidim_(size, value) {}
+
+ // Constructs an index from multi-dimensional index "multidim". The linear
+ // index is set to nullptr.
+ explicit Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim)
+ : multidim_(multidim.begin(), multidim.end()) {}
+
+ // Constructs an index from linear index "linear" and computes the
+ // multi-dimensional index from "linear" and "shape". "ir_builder" is the IR
+ // builder to emit the index of each dimension in the multi-dimensional
+ // index.
+ //
+ // Precondition: "shape" has a layout.
+ Index(llvm::Value* linear, const Shape& shape,
+ llvm::IRBuilder<>* ir_builder);
+
+ // Constructs an index from the given multi-dimensional index and the shape
+ // that it indexes into. Also, computes the linear index according to
+ // "shape".
+ //
+ // Precondition: "shape" has a layout.
+ Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ const Shape& shape, llvm::IRBuilder<>* ir_builder);
+
+    // Constructs an index from both a multi-dimensional index and a linear
+ // index. "shape" has the same meaning as that in the constructor that takes
+ // only a linear index.
+ Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
+ llvm::Value* linear, const Shape& shape);
+
+ const std::vector<llvm::Value*>& multidim() const { return multidim_; }
+ llvm::Value* linear() const { return linear_; }
+
+ size_t size() const { return multidim().size(); }
+
+ llvm::Value* operator[](size_t i) const { return multidim()[i]; }
+ llvm::Value*& operator[](size_t i) { return multidim()[i]; }
+
+ void push_back(llvm::Value* value) { multidim().push_back(value); }
+
+ using iterator = std::vector<llvm::Value*>::iterator;
+ using const_iterator = std::vector<llvm::Value*>::const_iterator;
+
+ iterator begin() { return multidim().begin(); }
+ iterator end() { return multidim().end(); }
+
+ const_iterator begin() const { return multidim().begin(); }
+ const_iterator end() const { return multidim().end(); }
+
+ bool LinearValidOnShape(const Shape& a) const;
+
+ // Given that "this" is the target index of a reshape from `operand_shape`
+ // to `shape`, returns the source index.
+ Index SourceIndexOfReshape(const Shape& shape, const Shape& operand_shape,
+ llvm::IRBuilder<>* builder) const;
+
+ // Given that "this" is the target index of a transpose from `operand_shape`
+ // to `shape` with the given dimension mapping, returns the source index.
+ Index SourceIndexOfTranspose(
+ const Shape& shape, const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping,
+ llvm::IRBuilder<>* builder) const;
+
+ // Linearizes the index into the given shape, i.e. reshapes it to rank-1 and
+ // returns the index into the sole dimension 0 of the new shape.
+ llvm::Value* Linearize(const Shape& shape,
+ llvm::IRBuilder<>* builder) const;
+
+ private:
+ // Changing the multi-dimensional index invalidates the linear index.
+ std::vector<llvm::Value*>& multidim() {
+ linear_ = nullptr;
+ return multidim_;
+ }
+
+ std::vector<llvm::Value*> multidim_;
+
+    // These values are purely for efficiency; `multidim_` is enough to find
+    // the element at a given `Index`. However, if a loop is emitted with a
+    // linear index space, that linear index can be saved in `linear_`, along
+    // with the layout and dimensions of the shape the loop was emitted for in
+    // `layout_` and `dims_`. If the `Index` is then used with another array
+    // whose layout and dimensions match, the linear index can be reused,
+    // sparing the cost of computing `multidim_`, which LLVM DCE can then
+    // delete. Modifying `multidim_` after construction nullifies `linear_`, so
+    // that a stale linear index is never used. If a loop is emitted with a
+    // multidimensional index space, `linear_` is null and `layout_` and
+    // `dims_` are ignored.
+ llvm::Value* linear_ = nullptr;
+ Layout layout_;
+ std::vector<int64> dims_;
+ };
+
+ // Default constructor. Constructs an IrArray in a null status.
+ IrArray() : base_ptr_(nullptr), shape_(nullptr) {}
+
+  // Constructs an IrArray with the given base pointer and shape. base_ptr is a
+  // pointer type pointing to the first element (lowest address) of the array.
+ IrArray(llvm::Value* base_ptr, const Shape& shape);
+
+ // Default implementations of copying and moving.
+ IrArray(IrArray&& other) = default;
+ IrArray(const IrArray& other) = default;
+ IrArray& operator=(IrArray&& other) = default;
+ IrArray& operator=(const IrArray& other) = default;
+
+ llvm::Value* GetBasePointer() const { return base_ptr_; }
+ llvm::Type* GetElementLlvmType() const { return element_type_; }
+
+ const Shape& GetShape() const {
+ CHECK(shape_ != nullptr);
+ return *shape_;
+ }
+
+ // Emit a sequence of instructions to compute the address of the element in
+ // the given array at the given index. Returns the address of the element as
+ // an LLVM Value.
+ //
+ // The optional name is useful for debugging when looking at
+ // the emitted LLVM IR.
+ llvm::Value* EmitArrayElementAddress(const Index& index,
+ llvm::IRBuilder<>* ir_builder,
+ tensorflow::StringPiece name = "") const;
+
+ // Emit IR to read an array element at the given index. Returns the read
+ // result (effectively, a Value loaded from memory). This method seamlessly
+ // handles scalar shapes by broadcasting their value to all indices (index is
+ // ignored).
+ //
+ // The optional name is useful for debugging when looking at
+ // the emitted LLVM IR.
+ llvm::Value* EmitReadArrayElement(const Index& index,
+ llvm::IRBuilder<>* ir_builder,
+ tensorflow::StringPiece name = "") const;
+
+ // Emit IR to write the given value to the array element at the given index.
+ void EmitWriteArrayElement(const Index& index, llvm::Value* value,
+ llvm::IRBuilder<>* ir_builder) const;
+
+ // Returns a new IrArray whose shape is "new_shape" and base pointer is a
+ // bitcast of the base pointer of "this" IrArray.
+ IrArray CastToShape(const Shape& new_shape,
+ llvm::IRBuilder<>* ir_builder) const;
+
+ void AddAliasScopeMetadata(llvm::MDNode* alias_scope) {
+ AddMetadata(llvm::LLVMContext::MD_alias_scope, alias_scope);
+ }
+
+ void AddNoaliasMetadata(llvm::MDNode* noalias) {
+ AddMetadata(llvm::LLVMContext::MD_noalias, noalias);
+ }
+
+ void AddInvariantLoad(llvm::MDNode* invariant_load) {
+ AddMetadata(llvm::LLVMContext::MD_invariant_load, invariant_load);
+ }
+
+ // Bumps the "which_dimension" value within the provided index by the provided
+ // addend.
+ static Index BumpIndex(const Index& index, int64 which_dimension,
+ int64 addend, llvm::IRBuilder<>* ir_builder);
+
+ private:
+ // Add the specified LLVM IR metadata to loads/stores associated with this
+ // IrArray.
+ void AddMetadata(int kind, llvm::MDNode* md) {
+ InsertOrDie(&metadata_, kind, md);
+ }
+
+ // Address of the base of the array as an LLVM Value.
+ llvm::Value* base_ptr_;
+
+ // The LLVM type of the elements in the array.
+ llvm::Type* element_type_;
+
+ // Shape of the XLA array.
+ const Shape* shape_;
+
+ // The list of key/value pairs used when attaching metadata to emitted
+  // loads/stores for this array. The keys are the metadata kinds and the
+ // values are the metadata nodes.
+ std::map<int, llvm::MDNode*> metadata_;
+};
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_IR_ARRAY_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
new file mode 100644
index 0000000000..4ccded61e7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -0,0 +1,197 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+
+#include <numeric>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/Constants.h"
+#include "external/llvm/include/llvm/IR/Function.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace llvm_ir {
+
+ForLoop::ForLoop(tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* step)
+ : suffix_(suffix.ToString()),
+ start_index_(start_index),
+ end_index_(end_index),
+ step_(step),
+ insert_before_bb_(nullptr) {}
+
+/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
+ tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* step, llvm::IRBuilder<>* ir_builder) {
+ std::unique_ptr<ForLoop> loop(
+ new ForLoop(suffix, start_index, end_index, step));
+ loop->Emit(ir_builder);
+ return loop;
+}
+
+void ForLoop::Emit(llvm::IRBuilder<>* ir_builder) {
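+  // Sketch of the emitted control flow (basic block names get the loop
+  // suffix appended):
+  //
+  //   preheader: store start_index to the induction variable; br header
+  //   header:    load indvar; br exit if indvar >= end_index, else br body
+  //   body:      (the caller emits the loop body here); indvar += step;
+  //              br header
+  //   exit:      code following the loop
+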
+ // The preheader block is the block the builder is currently emitting
+ // code into.
+ preheader_bb_ = ir_builder->GetInsertBlock();
+
+ llvm::BasicBlock::iterator insert_point = ir_builder->GetInsertPoint();
+ if (insert_point == preheader_bb_->end()) {
+ // We're emitting the loop at the end of a basic block. Verify there is no
+ // terminator (eg, branch) in the basic block.
+ CHECK_EQ(nullptr, preheader_bb_->getTerminator());
+
+ exit_bb_ = CreateBasicBlockWithSuffix("loop_exit", ir_builder);
+ } else {
+ // We're emitting the loop into the middle of a basic block. splitBasicBlock
+ // requires that this basic block be well-formed (have a terminator).
+ CHECK_NE(nullptr, preheader_bb_->getTerminator());
+
+ // Split the preheader to create an exit basic block. The exit basic block
+ // will contain all instructions at or after insert_point.
+ exit_bb_ = preheader_bb_->splitBasicBlock(
+ insert_point, GetNameWithSuffix("loop_exit").c_str());
+
+ // splitBasicBlock adds an unconditional branch between the split basic
+ // blocks. Remove it. An unconditional branch will be added below from the
+ // preheader to the header.
+ preheader_bb_->getTerminator()->eraseFromParent();
+ }
+ insert_before_bb_ = exit_bb_;
+
+  // Create the remaining basic blocks which form the inside of the loop.
+ header_bb_ = CreateBasicBlockWithSuffix("loop_header", ir_builder);
+ body_bb_ = CreateBasicBlockWithSuffix("loop_body", ir_builder);
+
+ // Function entry basic block.
+  // Emit the alloca for the induction variable. We do this in the function's
+  // entry basic block to ensure the alloca executes only once per function
+  // (we could be emitting a nested loop).
+ llvm::Function* func = preheader_bb_->getParent();
+ ir_builder->SetInsertPoint(&func->getEntryBlock(),
+ func->getEntryBlock().getFirstInsertionPt());
+ llvm::Value* indvar_address =
+ ir_builder->CreateAlloca(ir_builder->getInt64Ty(), nullptr,
+ GetNameWithSuffix("indvar_address").c_str());
+
+ // Preheader basic block.
+ // Initialize induction variable starting index. Create branch to the header.
+ ir_builder->SetInsertPoint(preheader_bb_);
+ ir_builder->CreateStore(start_index_, indvar_address);
+ // The preheader should not have a branch yet.
+ CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
+ ir_builder->CreateBr(header_bb_);
+
+ // Header basic block.
+ // Emit the loop conditional branch. Load the indvar and compare it with the
+ // ending index; jump to the loop exit if it is greater than or equal,
+ // otherwise jump to the body.
+ ir_builder->SetInsertPoint(header_bb_);
+ indvar_ = ir_builder->CreateLoad(indvar_address,
+ GetNameWithSuffix("indvar").c_str());
+ llvm::Value* exit_cond = ir_builder->CreateICmpUGE(indvar_, end_index_);
+ ir_builder->CreateCondBr(/*Cond=*/exit_cond,
+ /*True=*/exit_bb_, /*False=*/body_bb_);
+
+ // Body basic block.
+ // Increment the indvar, store the incremented value, and jump to the header.
+ ir_builder->SetInsertPoint(body_bb_);
+ llvm::Value* step = step_;
+ llvm::Value* indvar = indvar_;
+
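+ // The increment carries the "no unsigned wrap"/"no signed wrap" flags: the
+ // induction variable is assumed never to overflow, which gives LLVM more
+ // latitude when optimizing the loop.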
+ llvm::Value* indvar_inc =
+ ir_builder->CreateAdd(indvar, step, "indvar.inc",
+ /*HasNUW=*/true, /*HasNSW=*/true);
+ ir_builder->CreateStore(indvar_inc, indvar_address);
+ ir_builder->CreateBr(header_bb_);
+
+ // Re-point the IR builder to the loop exit block.
+ ir_builder->SetInsertPoint(exit_bb_);
+}
+
+string ForLoop::GetNameWithSuffix(tensorflow::StringPiece name) {
+ if (suffix_.empty()) {
+ return name.ToString();
+ } else {
+ return tensorflow::strings::StrCat(name, ".", suffix_);
+ }
+}
+
+llvm::BasicBlock* ForLoop::CreateBasicBlockWithSuffix(
+ tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) {
+ return CreateBasicBlock(insert_before_bb_, GetNameWithSuffix(name),
+ ir_builder);
+}
+
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(tensorflow::StringPiece suffix,
+ llvm::Value* start_index,
+ llvm::Value* end_index) {
+ if (inner_loop_body_bb_ != nullptr) {
+ // Create this loop inside the previous one.
+ ir_builder_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
+ }
+ std::unique_ptr<ForLoop> loop = ForLoop::EmitForLoop(
+ suffix, start_index, end_index, ir_builder_->getInt64(1), ir_builder_);
+
+ if (outer_loop_preheader_bb_ == nullptr) {
+ outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
+ }
+
+ if (outer_loop_exit_bb_ == nullptr) {
+ outer_loop_exit_bb_ = loop->GetExitBasicBlock();
+ }
+
+ inner_loop_body_bb_ = loop->GetBodyBasicBlock();
+
+ return loop;
+}
+
+std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
+ int64 end_index,
+ tensorflow::StringPiece suffix) {
+ CHECK_LE(start_index, end_index);
+ return AddLoop(suffix, ir_builder_->getInt64(start_index),
+ ir_builder_->getInt64(end_index));
+}
+
+IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
+ tensorflow::StringPiece suffix) {
+ std::vector<int64> dimensions(ShapeUtil::Rank(shape));
+ std::iota(dimensions.begin(), dimensions.end(), 0);
+ return AddLoopsForShapeOnDimensions(shape, dimensions, suffix);
+}
+
+IrArray::Index ForLoopNest::AddLoopsForShapeOnDimensions(
+ const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::StringPiece suffix) {
+ llvm_ir::IrArray::Index index(shape.dimensions_size(), nullptr);
+ for (int64 dimension : dimensions) {
+ std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/shape.dimensions(dimension),
+ /*suffix=*/tensorflow::strings::Printf(
+ "%s%lld", suffix.ToString().c_str(), dimension));
+ index[dimension] = loop->GetIndVarValue();
+ }
+ return index;
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
new file mode 100644
index 0000000000..0cc82b040d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h
@@ -0,0 +1,230 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_
+
+#include <memory>
+#include <string>
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// A class for constructing a for-loop in LLVM IR.
+class ForLoop {
+ public:
+ // Emit a for-loop at the current insert point of the given IRBuilder.
+ //
+ // start_index and end_index are the loop bounds (end_index is not inclusive).
+ // `step` is the increment of the loop index after each iteration.
+ //
+ // The current insert basic block of the builder is the preheader to the loop
+ // (see below for definition of basic block names). All instructions (if any)
+ // at or after the insert point in the insert basic block are moved to a newly
+ // created exit basic block. Instructions before the insert point remain in
+ // the insert BB:
+ //
+ // /--------------\ /----------------\
+ // | insert BB | | insert BB |
+ // | ... | | (preheader BB) |
+ // | %foo = ... | | ... |
+ // insert point ->| %bar = ... | ===> | %foo = ... |
+ // | ... | \----------------/
+ // \--------------/ |
+ // V
+ // [[ LOOP BBs ]]
+ // |
+ // V
+ // /--------------\
+ // | exit BB |
+ // | %bar = ... |
+ // | ... |
+ // \--------------/
+ //
+ // `suffix` is a string used to disambiguate variable and basic block names
+ // emitted in LLVM IR. This string is appended to the name of the induction
+ // variable value and each basic block created for the loop. The builder
+ // insert point is set to the end of the exit block after the function
+ // returns.
+ static std::unique_ptr<ForLoop> EmitForLoop(tensorflow::StringPiece suffix,
+ llvm::Value* start_index,
+ llvm::Value* end_index,
+ llvm::Value* step,
+ llvm::IRBuilder<>* ir_builder);
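+
+ // Illustrative usage (a sketch; `builder` is assumed to be an existing
+ // IRBuilder positioned in the function being built, and `n` an i64
+ // llvm::Value* bound provided by the caller):
+ //
+ //   std::unique_ptr<ForLoop> loop = ForLoop::EmitForLoop(
+ //       "copy", /*start_index=*/builder->getInt64(0), /*end_index=*/n,
+ //       /*step=*/builder->getInt64(1), builder);
+ //   builder->SetInsertPoint(loop->GetBodyBasicBlock(),
+ //                           loop->GetBodyBasicBlock()->getFirstInsertionPt());
+ //   // ... emit the per-iteration code here using loop->GetIndVarValue() ...
+ //   builder->SetInsertPoint(loop->GetExitBasicBlock());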
+
+ // The names of the blocks follow LLVM's conventions. Control flow amongst the
+ // blocks for the example C code looks like:
+ //
+ // for (int i = 0; i < n; ++i) {
+ // do_stuff(i);
+ // }
+ //
+ // /--------------\
+ // | preheader BB |
+ // | i = 0 |
+ // \--------------/
+ // |
+ // V
+ // /-------------\
+ // | header BB |<-+
+ // | if i < n: | |
+ // | goto body | |
+ // | else: | |
+ // | goto exit | |
+ // \-------------/ |
+ // | | |
+ // +--------+ | |
+ // | V |
+ // | /-------------\ |
+ // | | body BB | |
+ // | | dostuff(i) |--+
+ // | | ++i |
+ // | \-------------/
+ // |
+ // | /-------------\
+ // +->| exit BB |
+ // \-------------/
+ //
+ // Caller-emitted code to execute within the loop should be placed within the
+ // "body" basic block.
+ //
+ // Return pointers to various blocks in the loop.
+ llvm::BasicBlock* GetPreheaderBasicBlock() const { return preheader_bb_; }
+ llvm::BasicBlock* GetHeaderBasicBlock() const { return header_bb_; }
+ llvm::BasicBlock* GetBodyBasicBlock() const { return body_bb_; }
+ llvm::BasicBlock* GetExitBasicBlock() const { return exit_bb_; }
+
+ // Return the Value representing the induction variable in the body basic
+ // block of the loop.
+ llvm::Value* GetIndVarValue() const { return indvar_; }
+
+ private:
+ ForLoop(tensorflow::StringPiece suffix, llvm::Value* start_index,
+ llvm::Value* end_index, llvm::Value* step);
+
+ // Emit the loop at the insert point of the builder.
+ void Emit(llvm::IRBuilder<>* ir_builder);
+
+ llvm::BasicBlock* CreateBasicBlockWithSuffix(tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder);
+
+ // Create a name for an LLVM construct appending the member suffix_ if it is
+ // set.
+ string GetNameWithSuffix(tensorflow::StringPiece name);
+
+ string suffix_;
+ llvm::Value* start_index_;
+ llvm::Value* end_index_;
+ llvm::Value* step_;
+
+ // To improve readability of the IR, we want the basic blocks to appear
+ // consecutively in the following order: preheader, header, body, exit. The
+ // member insert_before_bb_ points to where the next basic block
+ // should be created to ensure this ordering.
+ llvm::BasicBlock* insert_before_bb_;
+
+ llvm::BasicBlock* preheader_bb_;
+ llvm::BasicBlock* header_bb_;
+ llvm::BasicBlock* body_bb_;
+ llvm::BasicBlock* exit_bb_;
+ llvm::Value* indvar_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ForLoop);
+};
+
+// A simple class for constructing nested for-loops.
+class ForLoopNest {
+ public:
+ explicit ForLoopNest(llvm::IRBuilder<>* ir_builder)
+ : outer_loop_preheader_bb_(nullptr),
+ outer_loop_exit_bb_(nullptr),
+ inner_loop_body_bb_(nullptr),
+ ir_builder_(ir_builder) {}
+
+ // Adds a loop to the nest. If no loop has been added yet, the loop is
+ // emitted at the current insert point of the given builder. If one or more
+ // loops have been added, the new loop is emitted inside the body of the
+ // most recently added loop.
+ std::unique_ptr<ForLoop> AddLoop(tensorflow::StringPiece suffix,
+ llvm::Value* start_index,
+ llvm::Value* end_index);
+
+ // A convenience wrapper around the other AddLoop overload for the case
+ // where the start and end indices are compile-time constants.
+ std::unique_ptr<ForLoop> AddLoop(int64 start_index, int64 end_index,
+ tensorflow::StringPiece suffix);
+
+ // Add loops to iterate through the indices within the specified
+ // shape. The returned index collects the induction variables of the
+ // loops so that it will iterate through all coordinates within the
+ // specified shape.
+ //
+ // E.g. if you pass in a 2x3 shape, you will get back an index with
+ // two entries that are induction variables of the two loops that
+ // will be added. That index will iterate through the 6 coordinates
+ // within the shape. One possible order for that sequence would be:
+ //
+ // (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)
+ IrArray::Index AddLoopsForShape(const Shape& shape,
+ tensorflow::StringPiece suffix);
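+
+ // Illustrative usage (a sketch; `shape` is assumed to be a non-scalar array
+ // shape and `builder` an existing IRBuilder):
+ //
+ //   ForLoopNest loop_nest(builder);
+ //   IrArray::Index index = loop_nest.AddLoopsForShape(shape, "init");
+ //   // Position the builder inside the innermost body before emitting the
+ //   // per-element code (SetToFirstInsertPoint is declared in llvm_util.h).
+ //   SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), builder);
+ //   // ... emit per-element code here using `index` ...
+ //   builder->SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());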
+
+ // Add a loop for each dimension in "dimensions". "suffix" is the
+ // name suffix of the indvar and basic blocks in this new loop nest.
+ //
+ // The return value is an index containing the induction variables. Its
+ // size equals the rank of the shape, and the entry is null for each
+ // dimension that is not in "dimensions".
+ IrArray::Index AddLoopsForShapeOnDimensions(
+ const Shape& shape, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::StringPiece suffix);
+
+ // Convenience methods which return particular basic blocks of the outermost
+ // or innermost loops. These methods return nullptr if no loops have been
+ // added yet.
+ llvm::BasicBlock* GetOuterLoopPreheaderBasicBlock() {
+ return outer_loop_preheader_bb_;
+ }
+ llvm::BasicBlock* GetOuterLoopExitBasicBlock() { return outer_loop_exit_bb_; }
+ llvm::BasicBlock* GetInnerLoopBodyBasicBlock() { return inner_loop_body_bb_; }
+
+ private:
+ // The preheader and exit basic block of the outermost loop, or nullptr if no
+ // loop has been added yet.
+ llvm::BasicBlock* outer_loop_preheader_bb_;
+ llvm::BasicBlock* outer_loop_exit_bb_;
+
+ // The body basic block of the most-recently added loop, or nullptr if no loop
+ // has been added yet.
+ llvm::BasicBlock* inner_loop_body_bb_;
+
+ llvm::IRBuilder<>* ir_builder_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ForLoopNest);
+};
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_LOOP_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
new file mode 100644
index 0000000000..d7a231db61
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc
@@ -0,0 +1,471 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/MDBuilder.h"
+#include "external/llvm/include/llvm/IR/Operator.h"
+#include "external/llvm/include/llvm/Target/TargetOptions.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+string AsString(const std::string& str) {
+ return string(str.data(), str.length());
+}
+
+llvm::StringRef AsStringRef(tensorflow::StringPiece str) {
+ return llvm::StringRef(str.data(), str.size());
+}
+
+string DumpModuleToString(const llvm::Module& module) {
+ std::string buffer_string;
+ llvm::raw_string_ostream ostream(buffer_string);
+ module.print(ostream, nullptr);
+ ostream.flush();
+ return AsString(buffer_string);
+}
+
+llvm::Value* EmitCallToIntrinsic(
+ llvm::Intrinsic::ID intrinsic_id,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
+ llvm::IRBuilder<>* ir_builder) {
+ std::vector<llvm::Type*> types;
+ for (auto type : overloaded_types) {
+ types.push_back(type);
+ }
+ llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
+ llvm::Function* intrinsic =
+ llvm::Intrinsic::getDeclaration(module, intrinsic_id, types);
+ std::vector<llvm::Value*> operands_vec;
+ for (auto operand : operands) {
+ operands_vec.push_back(operand);
+ }
+ return ir_builder->CreateCall(intrinsic, operands_vec);
+}
+
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Type* array_type = array->getType();
+ CHECK(array_type->isPointerTy());
+ llvm::PointerType* array_type_as_pointer =
+ llvm::cast<llvm::PointerType>(array_type);
+ VLOG(2) << "EmitBufferIndexingGEP with type="
+ << llvm_ir::DumpToString(*array_type)
+ << " array=" << llvm_ir::DumpToString(*array)
+ << " index=" << llvm_ir::DumpToString(*index);
+
+ return ir_builder->CreateInBoundsGEP(
+ array_type_as_pointer->getElementType(), array,
+ llvm::isa<llvm::GlobalVariable>(array)
+ ? llvm::ArrayRef<llvm::Value*>({ir_builder->getInt64(0), index})
+ : index);
+}
+
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
+ llvm::IRBuilder<>* ir_builder) {
+ return EmitBufferIndexingGEP(array, ir_builder->getInt64(index), ir_builder);
+}
+
+llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
+ llvm::IRBuilder<>* ir_builder) {
+ switch (element_type) {
+ case PRED:
+ case S8:
+ case U8:
+ return ir_builder->getInt8Ty();
+ case S16:
+ case U16:
+ return ir_builder->getInt16Ty();
+ case S32:
+ case U32:
+ return ir_builder->getInt32Ty();
+ case S64:
+ case U64:
+ return ir_builder->getInt64Ty();
+ case F32:
+ return ir_builder->getFloatTy();
+ case F64:
+ return ir_builder->getDoubleTy();
+ // A Tuple contains an array of pointers. Use i8*.
+ case TUPLE:
+ // An Opaque is like a void*, use i8*.
+ case OPAQUE:
+ return ir_builder->getInt8PtrTy();
+ default:
+ LOG(FATAL) << "unsupported type " << element_type;
+ }
+}
+
+llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) {
+ llvm::Type* result_type =
+ PrimitiveTypeToIrType(shape.element_type(), ir_builder);
+ if (ShapeUtil::IsTuple(shape)) {
+ // A tuple buffer is an array of pointers.
+ result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size());
+ } else {
+ for (int64 dimension : shape.layout().minor_to_major()) {
+ result_type =
+ llvm::ArrayType::get(result_type, shape.dimensions(dimension));
+ }
+ }
+ return result_type;
+}
+
+namespace {
+
+// Recursively construct a multidimensional LLVM constant which represents the
+// given literal. The minor-to-major dimension ordering in the constant matches
+// that of the literal. For example, given an F32 literal whose dimension 0
+// has size 4, dimension 1 has size 3, and dimension 2 has size 2, with a
+// minor_to_major value of [2, 1, 0] (dimension 2 is the minor-most), an LLVM
+// constant of type [4 x [3 x [2 x float]]] will be returned.
+//
+// multi_index is a multidimensional index into the array. dimension_index is an
+// index into the minor_to_major field in the literal shape. This determines
+// which dimension is iterated over in this level of the recursion. Dimensions
+// are iterated from most major down to most minor (highest dimension_index
+// value down to zero).
+llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index,
+ std::vector<int64>* multi_index,
+ llvm::IRBuilder<>* ir_builder) {
+ const Shape& shape = literal.shape();
+ llvm::Type* ir_element_type =
+ llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder);
+ if (dimension_index == -1) {
+ // Base case of the recursion. Index into the data field of the protobuf
+ // with the multi index.
+ llvm::Constant* value;
+ switch (shape.element_type()) {
+ case PRED:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<bool>(literal, *multi_index));
+ break;
+ case U8:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint8>(literal, *multi_index));
+ break;
+ case S32:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<int32>(literal, *multi_index));
+ break;
+ case U32:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint32>(literal, *multi_index));
+ break;
+ case S64:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<int64>(literal, *multi_index));
+ break;
+ case U64:
+ value = llvm::ConstantInt::get(
+ ir_element_type, LiteralUtil::Get<uint64>(literal, *multi_index));
+ break;
+ case F32:
+ value = llvm::ConstantFP::get(
+ ir_element_type, LiteralUtil::Get<float>(literal, *multi_index));
+ break;
+ case F64:
+ value = llvm::ConstantFP::get(
+ ir_element_type, LiteralUtil::Get<double>(literal, *multi_index));
+ break;
+ default:
+ LOG(FATAL) << "unsupported type " << shape.element_type();
+ }
+ return value;
+ }
+
+ // The dimension index starts at one less than the rank of the array and
+ // decrements with each recursive call. We want to iterate through the
+ // dimensions in major-to-minor order as we recurse, so just index into
+ // minor_to_major to get the dimension number for this level of the recursion.
+ int64 dimension = shape.layout().minor_to_major(dimension_index);
+
+ // Recursively call LiteralToConstant to construct subarrays for the
+ // more-minor dimensions. Gather the subarrays into a vector for bundling into
+ // a new (higher-dimensional) ConstantArray.
+ std::vector<llvm::Constant*> elements;
+ for (int64 i = 0; i < shape.dimensions(dimension); ++i) {
+ (*multi_index)[dimension] = i;
+ elements.push_back(LiteralToConstant(literal, dimension_index - 1,
+ multi_index, ir_builder));
+ }
+
+ llvm::Type* element_type;
+ if (elements.empty()) {
+ element_type = ir_element_type;
+ for (int i = 0; i < dimension_index; ++i) {
+ int64 index = shape.layout().minor_to_major(i);
+ element_type =
+ llvm::ArrayType::get(element_type, shape.dimensions(index));
+ }
+ } else {
+ element_type = elements[0]->getType();
+ }
+ llvm::ArrayType* aggregate_type =
+ llvm::ArrayType::get(element_type, shape.dimensions(dimension));
+ return llvm::ConstantArray::get(aggregate_type, elements);
+}
+
+} // namespace
+
+llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
+ llvm::IRBuilder<>* ir_builder) {
+ std::vector<int64> multi_index(ShapeUtil::Rank(literal.shape()), 0);
+ llvm::Constant* value = LiteralToConstant(
+ literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1,
+ &multi_index, ir_builder);
+ return value;
+}
+
+llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder,
+ int alignment) {
+ return EmitAllocaAtFunctionEntryWithCount(type, nullptr, name, ir_builder,
+ alignment);
+}
+
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
+ llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, int alignment) {
+ llvm::IRBuilder<>::InsertPoint insert_point = ir_builder->saveIP();
+ llvm::Function* function = ir_builder->GetInsertBlock()->getParent();
+ ir_builder->SetInsertPoint(&function->getEntryBlock(),
+ function->getEntryBlock().getFirstInsertionPt());
+ llvm::AllocaInst* alloca =
+ ir_builder->CreateAlloca(type, element_count, AsStringRef(name));
+ if (alignment != 0) {
+ alloca->setAlignment(alignment);
+ }
+ ir_builder->restoreIP(insert_point);
+ return alloca;
+}
+
+llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder) {
+ return llvm::BasicBlock::Create(
+ /*Context=*/ir_builder->getContext(),
+ /*Name=*/AsStringRef(name),
+ /*Parent=*/ir_builder->GetInsertBlock()->getParent(),
+ /*InsertBefore*/ insert_before);
+}
+
+LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, bool emit_else) {
+ llvm_ir::LlvmIfData if_data;
+ if_data.if_block = ir_builder->GetInsertBlock();
+ if_data.true_block = CreateBasicBlock(
+ nullptr, tensorflow::strings::StrCat(name, "-true"), ir_builder);
+ if_data.false_block =
+ emit_else ? CreateBasicBlock(nullptr,
+ tensorflow::strings::StrCat(name, "-false"),
+ ir_builder)
+ : nullptr;
+
+ // There is no fundamental reason this function cannot handle a block
+ // without a terminator; that case simply has not been implemented yet,
+ // because splitBasicBlock requires a terminator.
+ CHECK_NE(nullptr, if_data.if_block->getTerminator());
+ if_data.after_block = if_data.if_block->splitBasicBlock(
+ ir_builder->GetInsertPoint(),
+ AsStringRef(tensorflow::strings::StrCat(name, "-after")));
+
+ // splitBasicBlock inserts an unconditional terminator that we have
+ // to remove as we want a conditional branch there.
+ if_data.if_block->getTerminator()->eraseFromParent();
+
+ ir_builder->SetInsertPoint(if_data.if_block);
+ ir_builder->CreateCondBr(
+ condition, if_data.true_block,
+ emit_else ? if_data.false_block : if_data.after_block);
+
+ ir_builder->SetInsertPoint(if_data.true_block);
+ ir_builder->CreateBr(if_data.after_block);
+
+ if (emit_else) {
+ ir_builder->SetInsertPoint(if_data.false_block);
+ ir_builder->CreateBr(if_data.after_block);
+ }
+
+ ir_builder->SetInsertPoint(if_data.after_block,
+ if_data.after_block->getFirstInsertionPt());
+
+ return if_data;
+}
+
+llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
+ llvm::Value* lhs_value, llvm::Value* rhs_value,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Value* comparison_result;
+ if (lhs_value->getType()->isIntegerTy()) {
+ comparison_result = ir_builder->CreateICmp(predicate, lhs_value, rhs_value);
+ } else {
+ comparison_result = ir_builder->CreateFCmp(predicate, lhs_value, rhs_value);
+ }
+ // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1
+ // arrays. So we extend it to i8 so that it's addressable.
+ return ir_builder->CreateZExt(
+ comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder));
+}
+
+// Internal helper that is called from emitted code to log an int64 value with a
+// tag.
+static void LogS64(const char* tag, int64 value) {
+ LOG(INFO) << tag << " (int64): " << value;
+}
+
+void EmitLogging(const char* tag, llvm::Value* value,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::FunctionType* log_function_type = llvm::FunctionType::get(
+ ir_builder->getVoidTy(),
+ {ir_builder->getInt64Ty(), ir_builder->getInt64Ty()}, /*isVarArg=*/false);
+ ir_builder->CreateCall(
+ log_function_type,
+ ir_builder->CreateIntToPtr(
+ ir_builder->getInt64(tensorflow::bit_cast<int64>(&LogS64)),
+ log_function_type->getPointerTo()),
+ {ir_builder->getInt64(tensorflow::bit_cast<int64>(tag)), value});
+}
+
+void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
+ bool is_pointer_to) {
+ legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags();
+ if (!flags->xla_emit_tbaa) {
+ return;
+ }
+
+ llvm::MDBuilder metadata_builder(instruction->getContext());
+ llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA");
+ string type_name;
+ if (is_pointer_to) {
+ type_name += "pointer-to ";
+ }
+ // Scalars do not have a layout, which makes it permissible to omit an explicit
+ // layout. To make sure that equivalent scalar shapes have the same TBAA,
+ // remove the (meaningless) explicit layout if one is present.
+ if (ShapeUtil::Rank(shape) == 0) {
+ LayoutUtil::ClearLayout(&shape);
+ } else {
+ CHECK(shape.has_layout());
+ }
+ type_name += shape.ShortDebugString();
+ llvm::MDNode* tbaa_node =
+ metadata_builder.createTBAANode(llvm_ir::AsStringRef(type_name), root);
+ instruction->setMetadata(llvm::LLVMContext::MD_tbaa,
+ metadata_builder.createTBAAStructTagNode(
+ tbaa_node, tbaa_node, /*Offset=*/0));
+}
+
+void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) {
+ llvm::LLVMContext& context = load->getContext();
+ llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
+ llvm::Constant* alignment_constant =
+ llvm::ConstantInt::get(int64_ty, alignment);
+ llvm::MDBuilder metadata_builder(context);
+ auto* alignment_metadata =
+ metadata_builder.createConstant(alignment_constant);
+ load->setMetadata(llvm::LLVMContext::MD_align,
+ llvm::MDNode::get(context, alignment_metadata));
+}
+
+void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ uint64_t dereferenceable_bytes) {
+ llvm::LLVMContext& context = load->getContext();
+ llvm::Type* int64_ty = llvm::Type::getInt64Ty(context);
+ llvm::Constant* dereferenceable_bytes_constant =
+ llvm::ConstantInt::get(int64_ty, dereferenceable_bytes);
+ llvm::MDBuilder metadata_builder(context);
+ auto* dereferenceable_bytes_metadata =
+ metadata_builder.createConstant(dereferenceable_bytes_constant);
+ load->setMetadata(llvm::LLVMContext::MD_dereferenceable,
+ llvm::MDNode::get(context, dereferenceable_bytes_metadata));
+}
+
+llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
+ llvm::Instruction* inst) {
+ llvm::LLVMContext& context = inst->getParent()->getContext();
+ llvm::IntegerType* i32 = llvm::Type::getInt32Ty(context);
+ inst->setMetadata(
+ llvm::LLVMContext::MD_range,
+ llvm::MDNode::get(
+ context,
+ {llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, lower)),
+ llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(i32, upper))}));
+ return inst;
+}
+
+string SanitizeIrName(string function_name) {
+ // Replace some characters that cannot occur in LLVM names with '_'
+ std::replace(function_name.begin(), function_name.end(), '.', '_');
+ std::replace(function_name.begin(), function_name.end(), '%', '_');
+ std::replace(function_name.begin(), function_name.end(), '-', '_');
+ return function_name;
+}
+
+void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder) {
+ builder->SetInsertPoint(blk, blk->getFirstInsertionPt());
+}
+
+llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
+ llvm::IRBuilder<>* builder) {
+ auto size = rotand->getType()->getPrimitiveSizeInBits();
+ auto size_value = builder->getIntN(size, size);
+ auto mod = [=](llvm::Value* x) { return builder->CreateURem(x, size_value); };
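+ // Both shift amounts are reduced modulo the bit width: in LLVM IR a shift
+ // by an amount >= the bit width produces an undefined (poison) value, and
+ // `size - rotor` equals the bit width when `rotor` is zero.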
+ return builder->CreateOr(
+ builder->CreateShl(rotand, mod(builder->CreateSub(size_value, rotor))),
+ builder->CreateLShr(rotand, mod(rotor)));
+}
+
+int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) {
+ unsigned pointer_size = data_layout.getPointerSize();
+ return ShapeUtil::ByteSizeOf(shape, pointer_size);
+}
+
+void SetFastMathFlags(llvm::FastMathFlags* fast_math_flags) {
+ auto* flags = legacy_flags::GetLlvmBackendFlags();
+ if (flags->xla_precision_losing_optimizations) {
+ fast_math_flags->setAllowReciprocal();
+ }
+ if (flags->xla_fast_math) {
+ fast_math_flags->setUnsafeAlgebra();
+ }
+}
+
+void SetTargetOptions(llvm::TargetOptions* options) {
+ auto* flags = legacy_flags::GetLlvmBackendFlags();
+ options->LessPreciseFPMADOption = options->UnsafeFPMath =
+ flags->xla_fast_math || flags->xla_precision_losing_optimizations;
+ options->NoInfsFPMath = options->NoNaNsFPMath = flags->xla_fast_math;
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
new file mode 100644
index 0000000000..56c26b3800
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h
@@ -0,0 +1,228 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
+
+#include <stdint.h>
+#include <string>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/StringRef.h"
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "external/llvm/include/llvm/IR/Module.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "external/llvm/include/llvm/Support/raw_ostream.h"
+#include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace llvm {
+class FastMathFlags;
+class TargetOptions;
+}  // namespace llvm
+
+namespace xla {
+namespace llvm_ir {
+
+// Convert a std::string (used by LLVM's interfaces) to string.
+string AsString(const std::string& str);
+
+// Convert a tensorflow::StringPiece to a llvm::StringRef. Note: both
+// tensorflow::StringPiece and llvm::StringRef are non-owning pointers into a
+// string in memory. This method is used to feed strings to LLVM
+// & Clang APIs that expect llvm::StringRef.
+llvm::StringRef AsStringRef(tensorflow::StringPiece str);
+
+template <typename T>
+llvm::ArrayRef<T> AsArrayRef(const std::vector<T>& vec) {
+ return llvm::ArrayRef<T>(vec.data(), vec.size());
+}
+
+template <typename T>
+llvm::ArrayRef<T> AsArrayRef(const tensorflow::gtl::ArraySlice<T>& slice) {
+ return llvm::ArrayRef<T>(slice.data(), slice.size());
+}
+
+// Dump the given LLVM entity to a string. This works for Types and Values.
+template <typename T>
+string DumpToString(const T& entity) {
+ std::string buffer_string;
+ llvm::raw_string_ostream ostream(buffer_string);
+ entity.print(ostream);
+ ostream.flush();
+ return AsString(buffer_string);
+}
+
+// Dump the given LLVM module to a string. This requires a function distinct
+// from DumpToString because the signatures of the print() methods for Values
+// and Modules are slightly different.
+string DumpModuleToString(const llvm::Module& module);
+
+// Sanitizes the given name to be a valid LLVM IR value name.
+string SanitizeIrName(string name);
+
+// Emits a call to the specified intrinsic with the given operands. Overloaded
+// intrinsics (for example, "minnum") must include a type in overloaded_types
+// for each overloaded type. Typically, overloaded intrinsics have only a single
+// overloaded type.
+llvm::Value* EmitCallToIntrinsic(
+ llvm::Intrinsic::ID intrinsic_id,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ tensorflow::gtl::ArraySlice<llvm::Type*> overloaded_types,
+ llvm::IRBuilder<>* ir_builder);
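+
+// Illustrative usage (a sketch; `x` and `y` are assumed to be existing float
+// values and `b` an IRBuilder):
+//
+//   llvm::Value* m = EmitCallToIntrinsic(llvm::Intrinsic::minnum, {x, y},
+//                                        {b->getFloatTy()}, b);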
+
+// Convenience methods for emitting a GEP instruction that indexes into a buffer
+// (1-dimensional array), equivalent to array[index]. The type is automatically
+// determined from the element type of the array. The int64 index overload
+// wraps the index in an i64 llvm::Value.
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, llvm::Value* index,
+ llvm::IRBuilder<>* ir_builder);
+llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index,
+ llvm::IRBuilder<>* ir_builder);
+
+// Returns the LLVM type which represents the given XLA primitive type.
+llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type,
+ llvm::IRBuilder<>* ir_builder);
+
+// Returns the LLVM type which represents the given XLA shape. For example,
+// if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]].
+llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder);
+
+// Converts a given literal to an IR Constant. Literals have known constant
+// values at IR emission time.
+llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
+ llvm::IRBuilder<>* ir_builder);
+
+// Inserts an alloca of the requested type at the entry point of the
+// function that the builder is currently building. The insert point
+// of the builder is set to the same place after calling this function
+// as before.
+//
+// This can be useful to avoid e.g. executing an alloca every time
+// through a loop.
+llvm::AllocaInst* EmitAllocaAtFunctionEntry(llvm::Type* type,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder,
+ int alignment = 0);
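+
+// Illustrative usage (a sketch; even if `b` is currently positioned inside a
+// loop body, the alloca itself is emitted in the function's entry block):
+//
+//   llvm::AllocaInst* scratch =
+//       EmitAllocaAtFunctionEntry(b->getInt64Ty(), "scratch", b);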
+
+// As EmitAllocaAtFunctionEntry, but allocates element_count entries
+// instead of a single element.
+llvm::AllocaInst* EmitAllocaAtFunctionEntryWithCount(
+ llvm::Type* type, llvm::Value* element_count, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, int alignment = 0);
+
+// Creates a basic block in the same context and function as the builder's
+// current insert block. Inserts at the end of the function if insert_before
+// is null.
+llvm::BasicBlock* CreateBasicBlock(llvm::BasicBlock* insert_before,
+ tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder);
+
+// Struct with data on a conditional branch in a diamond shape created
+// via EmitIfThenElse.
+struct LlvmIfData {
+ // The block that has the conditional branch.
+ llvm::BasicBlock* if_block;
+
+ // The block that is executed if the condition is true.
+ llvm::BasicBlock* true_block;
+
+ // The block that is executed if the condition is false.
+ llvm::BasicBlock* false_block;
+
+ // The block that follows after both the true_block and the
+ // false_block.
+ llvm::BasicBlock* after_block;
+};
+
+// Inserts a diamond-shaped if-then-else construct at the current
+// insertion point of the builder. This involves splitting the current
+// block into two blocks, at the insertion point, and introducing a
+// true-block and a false-block that connect the two split pieces. The
+// true-block is executed if the condition parameter evaluates to true
+// and otherwise the false-block is executed. If `emit_else` is false,
+// it jumps to the after-block rather than the false-block if the
+// condition is false, and the returned `false_block` is null.
+//
+// Currently the insertion point of the builder must be in a well-formed
+// block with a terminator. If you need to use this for a non-terminated
+// block, extend the function to handle that case.
+LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name,
+ llvm::IRBuilder<>* ir_builder, bool emit_else = true);
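+
+// Illustrative usage (a sketch; `cond` is assumed to be an existing i1 value
+// and `b` an IRBuilder whose insert point lies in a terminated block):
+//
+//   LlvmIfData if_data = EmitIfThenElse(cond, "is_zero", b);
+//   SetToFirstInsertPoint(if_data.true_block, b);
+//   // ... emit the "then" code ...
+//   SetToFirstInsertPoint(if_data.false_block, b);
+//   // ... emit the "else" code ...
+//   // Reposition the builder at the merge point before continuing.
+//   SetToFirstInsertPoint(if_data.after_block, b);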
+
+// Emits a compare operation between "lhs" and "rhs" with the given predicate,
+// and then converts the result to i8 so that it is addressable.
+llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate,
+ llvm::Value* lhs, llvm::Value* rhs,
+ llvm::IRBuilder<>* ir_builder);
+
+// Emits a call that logs the given value with the given tag as a prefix.
+// The provided tag and value are passed to a runtime logging call that is
+// embedded in this translation unit when the emitted code is executed.
+//
+// This can be very useful for debugging generated programs in short order when
+// developing new generated routines.
+//
+// Precondition: value must be an int64.
+// Precondition: tag must be a stable pointer for the lifetime of the generated
+// program (the constant pointer is burned in to the program).
+void EmitLogging(const char* tag, llvm::Value* value,
+ llvm::IRBuilder<>* ir_builder);
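+
+// Illustrative usage (a sketch; `indvar` is assumed to be an existing i64
+// value):
+//
+//   EmitLogging("loop indvar", indvar, ir_builder);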
+
+// Adds TBAA metadata to a load or store instruction using the given shape as
+// its type. The is_pointer_to parameter is used to indicate whether or not
+// this instruction loads or stores a pointer to an array.
+void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape,
+ bool is_pointer_to);
+
+// Adds alignment metadata to a load instruction using the given alignment.
+// The alignment refers to the result of the load, not the load itself.
+void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment);
+
+// Adds dereferenceable metadata to a load instruction using the given
+// number of dereferenceable bytes.
+// Dereferenceable refers to the result of the load, not the load itself.
+void SetDereferenceableMetadataForLoad(llvm::LoadInst* load,
+ uint64_t dereferenceable_bytes);
+
+// Tells LLVM `inst >= lower && inst < upper`. Returns `inst` for convenience.
+llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
+ llvm::Instruction* inst);
+
+void SetToFirstInsertPoint(llvm::BasicBlock* blk, llvm::IRBuilder<>* builder);
+
+// Create a bitwise rotation of `rotand` by `rotor`.
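+// With all shift amounts taken modulo the bit width n of `rotand`, this
+// computes:
+//
+//   ror(x, r) = (x >> (r % n)) | (x << ((n - r) % n))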
+llvm::Value* CreateRor(llvm::Value* rotand, llvm::Value* rotor,
+ llvm::IRBuilder<>* builder);
+
+// Returns the number of bytes within the shape.
+int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout);
+
+// Set values in the given FastMathFlags struct according to our XLA flags.
+void SetFastMathFlags(llvm::FastMathFlags* flags);
+
+// Set values in the given TargetOptions struct according to our XLA flags.
+void SetTargetOptions(llvm::TargetOptions* options);
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LLVM_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
new file mode 100644
index 0000000000..9a128b2aa6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
+ llvm::IRBuilder<>* ir_builder)
+ : body_emitter_(body_emitter), shape_(shape), ir_builder_(ir_builder) {}
+
+LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
+ const IrArray& target_array,
+ llvm::IRBuilder<>* ir_builder)
+ : body_emitter_([=](const llvm_ir::IrArray::Index array_index)
+ -> ::tensorflow::Status {
+ // Convert target_element_generator to a BodyEmitter.
+ TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
+ target_element_generator(array_index));
+ target_array.EmitWriteArrayElement(array_index, target_element,
+ ir_builder);
+ return tensorflow::Status::OK();
+ }),
+ shape_(target_array.GetShape()),
+ ir_builder_(ir_builder) {}
+
+IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock() {
+ CHECK(!ShapeUtil::IsTuple(shape_));
+ if (ShapeUtil::IsScalar(shape_)) {
+ // No loop needed, so set exit_bb_ to nullptr.
+ exit_bb_ = nullptr;
+ return IrArray::Index();
+ }
+
+ // Create a loop nest with one for-loop for each dimension of the target
+ // shape. ForLoopNest adds loops from outermost to innermost, so emit the
+ // loops in order from the most-major dimension down to the most-minor
+ // dimension of the target shape.
+ ForLoopNest loop_nest(ir_builder_);
+ IrArray::Index array_index(shape_.dimensions_size());
+ for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) {
+ int64 dimension = shape_.layout().minor_to_major(i);
+ std::unique_ptr<ForLoop> loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/shape_.dimensions(dimension),
+ /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ array_index[dimension] = loop->GetIndVarValue();
+ }
+
+ // Set IR builder insertion point to the loop body basic block of the
+ // innermost loop.
+ llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock();
+ ir_builder_->SetInsertPoint(innermost_body_bb,
+ innermost_body_bb->getFirstInsertionPt());
+
+ // Set exit_bb_ to the exit block of the loop nest.
+ exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
+ CHECK_NOTNULL(exit_bb_);
+
+ return array_index;
+}
+
+tensorflow::Status LoopEmitter::EmitLoop() {
+ IrArray::Index array_index = EmitIndexAndSetExitBasicBlock();
+ TF_RETURN_IF_ERROR(body_emitter_(array_index));
+
+ // Set the insertion point of ir_builder_ to the loop exit, so that
+ // code emitted for later instructions will be correctly placed.
+ if (exit_bb_ != nullptr) {
+ ir_builder_->SetInsertPoint(exit_bb_);
+ }
+ return tensorflow::Status::OK();
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
new file mode 100644
index 0000000000..08171e9e9d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -0,0 +1,79 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_
+
+#include <functional>
+
+#include "external/llvm/include/llvm/IR/BasicBlock.h"
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// A function type for emitting code that generates an element in the target
+// array. The function gets a multi-dimensional index as its only input. This
+// index specifies the target element for which a value needs to be computed.
+// The function has to emit code to compute this value and return the resulting
+// llvm::Value*.
+using ElementGenerator =
+ std::function<StatusOr<llvm::Value*>(const IrArray::Index& index)>;
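+
+// Illustrative element generator (a sketch; `operand` is assumed to be an
+// existing IrArray and `ir_builder` the IRBuilder used for emission; the
+// generator simply forwards the corresponding element of `operand`):
+//
+//   ElementGenerator gen = [&](const IrArray::Index& index)
+//       -> StatusOr<llvm::Value*> {
+//     return operand.EmitReadArrayElement(index, ir_builder);
+//   };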
+
+// Emits a loop nest that iterates over every element in the given shape.
+class LoopEmitter {
+ public:
+ using BodyEmitter =
+ std::function<tensorflow::Status(const IrArray::Index& index)>;
+
+ LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
+ llvm::IRBuilder<>* ir_builder);
+ // Constructs a LoopEmitter from an element generator that generates each
+ // element of the given target array.
+ LoopEmitter(const ElementGenerator& target_element_generator,
+ const IrArray& target_array, llvm::IRBuilder<>* ir_builder);
+ LoopEmitter(const LoopEmitter&) = delete;
+ LoopEmitter& operator=(const LoopEmitter&) = delete;
+ virtual ~LoopEmitter() = default;
+
+ // Emits a loop nest (with a yet-to-be-filled loop body) that iterates through
+ // every element in the given shape. Returns the multi-dimensional index that
+ // specifies the element.
+ virtual IrArray::Index EmitIndexAndSetExitBasicBlock();
+
+ // Emits a complete loop nest for every element in the given shape.
+ tensorflow::Status EmitLoop();
+
+ protected:
+ // An IR emitter that generates the loop body.
+ BodyEmitter body_emitter_;
+
+ // The shape that the emitted loop iterates through.
+ Shape shape_;
+
+ // Points to the exit block of the emitted loop. If the given shape is
+ // scalar, no loops are emitted and exit_bb_ is nullptr in that case.
+ llvm::BasicBlock* exit_bb_;
+
+ llvm::IRBuilder<>* ir_builder_;
+};
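+
+// Illustrative usage (a sketch; `gen` is an ElementGenerator as above and
+// `result` an IrArray wrapping the output buffer):
+//
+//   TF_RETURN_IF_ERROR(LoopEmitter(gen, result, ir_builder).EmitLoop());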
+
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_LOOP_EMITTER_H_
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
new file mode 100644
index 0000000000..e01d25d250
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
+
+#include <stddef.h>
+#include <string>
+#include <vector>
+
+#include "external/llvm/include/llvm/IR/Instructions.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace llvm_ir {
+
+void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
+ llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) {
+ CHECK(ShapeUtil::IsScalar(pred.GetShape()));
+
+ llvm::LoadInst* pred_value =
+ ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value");
+ llvm::Value* pred_cond = ir_builder->CreateICmpNE(
+ pred_value,
+ llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0),
+ "boolean_predicate");
+
+ VLOG(2) << "HandleSelect for tuple:";
+ VLOG(2) << " pred_value: " << DumpToString(*pred_value);
+ VLOG(2) << " pred_cond: " << DumpToString(*pred_cond);
+
+ for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) {
+ std::vector<llvm::Value*> element_index = {ir_builder->getInt64(0),
+ ir_builder->getInt64(i)};
+ llvm::Value* on_true_element_address =
+ ir_builder->CreateInBoundsGEP(on_true, element_index);
+ llvm::Value* on_true_element = ir_builder->CreateLoad(
+ on_true_element_address,
+ tensorflow::strings::Printf("on_true_element_%d", i).c_str());
+ llvm::Value* on_false_element_address =
+ ir_builder->CreateInBoundsGEP(on_false, element_index);
+ llvm::Value* on_false_element = ir_builder->CreateLoad(
+ on_false_element_address,
+ tensorflow::strings::Printf("on_false_element_%d", i).c_str());
+
+ llvm::Value* output_element_address =
+ ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index);
+ ir_builder->CreateStore(
+ ir_builder->CreateSelect(
+ pred_cond, on_true_element, on_false_element,
+ tensorflow::strings::Printf("select_output_element_%d", i).c_str()),
+ output_element_address);
+ }
+}
+
+void EmitTuple(IrArray tuple,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ llvm::IRBuilder<>* ir_builder) {
+ for (size_t i = 0; i < operands.size(); ++i) {
+ ir_builder->CreateStore(
+ ir_builder->CreatePointerCast(operands[i],
+ PrimitiveTypeToIrType(TUPLE, ir_builder)),
+ ir_builder->CreateInBoundsGEP(
+ tuple.GetBasePointer(),
+ {ir_builder->getInt64(0), ir_builder->getInt64(i)}));
+ }
+}
+
+llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
+ int alignment, llvm::Value* operand,
+ llvm::IRBuilder<>* ir_builder) {
+ llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP(
+ operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)});
+ llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr);
+ SetTbaaForInstruction(src_buffer, target_shape, /*is_pointer_to=*/true);
+ SetAlignmentMetadataForLoad(src_buffer, alignment);
+ llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder);
+ llvm::Value* ret_val =
+ ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo());
+ return ret_val;
+}
+
+} // namespace llvm_ir
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h
new file mode 100644
index 0000000000..af4063c340
--- /dev/null
+++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h
@@ -0,0 +1,79 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
+
+#include "external/llvm/include/llvm/IR/IRBuilder.h"
+#include "external/llvm/include/llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace llvm_ir {
+
+// Selection among tuples is special in how it's lowered, because a tuple is not
+// an HLO array.
+//
+// tuple_on_true tuple_on_false
+// | |
+// V V
+// ------------------------ ------------------------
+// | address of element 0 | | address of element 0 |
+// |----------------------| |----------------------|
+// | address of element 1 | | address of element 1 |
+// |----------------------| |----------------------|
+// | address of element 2 | | address of element 2 |
+// ------------------------ ------------------------
+// \ /
+// \ /
+// ----------
+// pred ---------> | select |
+// ----------
+// |
+// V
+// output ----> ------------------------
+// | address of element 0 |
+// |----------------------|
+// | address of element 1 |
+// |----------------------|
+// | address of element 2 |
+// ------------------------
+//
+// Only the addresses are copied to the output. For each element, we emit a copy
+// of the address from the corresponding element in either
+// tuple_on_true or tuple_on_false:
+// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i]
+void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true,
+ llvm::Value* on_false, llvm::IRBuilder<>* ir_builder);
+
+// A tuple is an array of pointers, one for each operand. Each pointer points to
+// the output buffer of its corresponding operand.
+void EmitTuple(IrArray tuple,
+ tensorflow::gtl::ArraySlice<llvm::Value*> operands,
+ llvm::IRBuilder<>* ir_builder);
+
+// A tuple is an array of pointers, one for each operand. Each pointer points to
+// the output buffer of its corresponding operand. A GetTupleElement instruction
+// forwards the pointer to the underlying tuple element buffer at the given index.
+// Returns an llvm value representing a pointer to the tuple element buffer.
+llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index,
+ int alignment, llvm::Value* operand,
+ llvm::IRBuilder<>* ir_builder);
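+
+// Illustrative usage (a sketch; `tuple_ptr` is assumed to point at a tuple
+// buffer whose element 0 has shape `elem_shape`):
+//
+//   llvm::Value* elem0 = EmitGetTupleElement(elem_shape, /*index=*/0,
+//                                            /*alignment=*/8, tuple_ptr,
+//                                            ir_builder);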
+} // namespace llvm_ir
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
new file mode 100644
index 0000000000..38465e37e7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -0,0 +1,543 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/local_service.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+LocalExecuteOptions& LocalExecuteOptions::set_platform(
+ perftools::gputools::Platform* platform) {
+ platform_ = platform;
+ return *this;
+}
+
+perftools::gputools::Platform* LocalExecuteOptions::platform() const {
+ return platform_;
+}
+
+LocalExecuteOptions& LocalExecuteOptions::set_device_ordinal(
+ int device_ordinal) {
+ device_ordinal_ = device_ordinal;
+ return *this;
+}
+
+int LocalExecuteOptions::device_ordinal() const { return device_ordinal_; }
+
+LocalExecuteOptions& LocalExecuteOptions::set_allocator(
+ DeviceMemoryAllocator* allocator) {
+ allocator_ = allocator;
+ return *this;
+}
+
+DeviceMemoryAllocator* LocalExecuteOptions::allocator() const {
+ return allocator_;
+}
+
+LocalExecuteOptions& LocalExecuteOptions::set_stream(
+ perftools::gputools::Stream* stream) {
+ stream_ = stream;
+ return *this;
+}
+
+perftools::gputools::Stream* LocalExecuteOptions::stream() const {
+ return stream_;
+}
+
+LocalExecuteOptions& LocalExecuteOptions::set_execution_profile(
+ ExecutionProfile* profile) {
+ profile_ = profile;
+ return *this;
+}
+
+ExecutionProfile* LocalExecuteOptions::execution_profile() const {
+ return profile_;
+}
+
+LocalExecuteOptions& LocalExecuteOptions::set_result_layout(
+ const Shape& shape_with_layout) {
+ has_result_shape_with_layout_ = true;
+ result_shape_with_layout_ = shape_with_layout;
+ return *this;
+}
+
+const Shape* LocalExecuteOptions::result_layout() const {
+ return has_result_shape_with_layout_ ? &result_shape_with_layout_ : nullptr;
+}
+
+/* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
+ perftools::gputools::Platform* platform) {
+ ServiceOptions default_options;
+ default_options.set_platform(platform);
+ return NewService(default_options);
+}
+
+/* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
+ const ServiceOptions& options) {
+ perftools::gputools::Platform* platform = options.platform();
+ if (platform == nullptr) {
+ TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Backend> backend,
+ Backend::CreateBackend(platform, options.number_of_replicas()));
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
+ CreateComputeConstantBackend());
+ std::unique_ptr<LocalService> service(new LocalService(
+ std::move(backend), std::move(compute_constant_backend)));
+ return std::move(service);
+}
+
+LocalService::LocalService(std::unique_ptr<Backend> execute_backend,
+ std::unique_ptr<Backend> compute_constant_backend)
+ : Service(std::move(execute_backend), std::move(compute_constant_backend)) {
+ runs_in_client_process_ = true;
+}
+
+tensorflow::Status LocalService::ResolveArguments(
+ const tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* argument_ptrs) {
+ TF_ASSIGN_OR_RETURN(std::vector<const Allocation*> arg_allocations,
+ ResolveAndValidateArguments(
+ arguments, execute_backend_.get(), device_ordinal));
+ argument_ptrs->resize(arg_allocations.size());
+ for (int i = 0; i < arguments.size(); ++i) {
+ const Allocation& allocation = *arg_allocations[i];
+ (*argument_ptrs)[i] = allocation.device_memory();
+ }
+ return tensorflow::Status::OK();
+}
+
+namespace {
+// Returns the space required to allocate a shape. If
+// allocate_space_for_deep_copy is true, the space includes all sub-buffers of
+// a tuple.
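+//
+// For example (illustrative), for a tuple shape (f32[4], f32[8]) with
+// allocate_space_for_deep_copy set, the result is the size of the top-level
+// pointer array plus the sizes of both element buffers.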
+int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
+ TransferManager* transfer_manager) {
+ int64 size = 0;
+ // TODO(b/33492279) remove once no devices represent result tuples as
+ // contiguous buffers.
+ if (allocate_space_for_deep_copy) {
+ TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+ shape, [&size, transfer_manager](const Shape& subshape,
+ const ShapeIndex& /*index*/) {
+ size += transfer_manager->GetByteSizeRequirement(subshape);
+ return tensorflow::Status::OK();
+ }));
+ }
+ return size;
+}
+} // namespace
+
+StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
+ const Shape& shape, int device_ordinal, bool allocate_space_for_deep_copy) {
+ int64 allocation_size = RequiredSpace(shape, allocate_space_for_deep_copy,
+ execute_backend_->transfer_manager());
+
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
+ execute_backend_->memory_allocator()->Allocate(
+ device_ordinal, allocation_size));
+
+ return allocation_tracker_.Register(
+ execute_backend_.get(), device_ordinal, allocation, shape,
+ tensorflow::strings::StrCat("AllocateBufferOnDevice of size ",
+ allocation_size));
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocally(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options) {
+ return ExecuteLocallyInternal(computation, arguments, options,
+ /*preallocated_result_buffer=*/nullptr);
+}
+
+tensorflow::Status LocalService::ExecuteLocally(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result_buffer) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<ShapedBuffer> null_buffer,
+ ExecuteLocallyInternal(computation, arguments, options, result_buffer));
+ // Because the result is written into result_buffer, a null ShapedBuffer
+ // pointer should have been returned.
+ CHECK_EQ(nullptr, null_buffer.get());
+ return tensorflow::Status::OK();
+}
+
+StatusOr<std::unique_ptr<AotCompilationResult>>
+LocalService::CompileAheadOfTime(
+ const ComputationHandle& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const Shape& result_layout, const AotCompilationOptions& options) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(computation));
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> hlo_module,
+ computation_tracker_.BuildHloModule(versioned_handle,
+ /*include_unused_parameters=*/true));
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
+ auto* computation_layout = module_config->mutable_entry_computation_layout();
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& argument_layout = *argument_layouts[i];
+ if (ShapeUtil::IsTuple(argument_layout)) {
+ return Unimplemented("tuple arguments not supported yet");
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ argument_layout));
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ result_layout));
+
+ return execute_backend_->compiler()
+ ->CompileAheadOfTime(std::move(hlo_module), std::move(module_config),
+ MakeHloDumper(), options)
+ .ConsumeValueOrDie();
+}
+
+tensorflow::Status LocalService::ValidateExecuteOptions(
+ const ProgramShape& program_shape,
+ tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const LocalExecuteOptions& options,
+ const ShapedBuffer* preallocated_result_buffer) {
+ if (argument_layouts.size() != program_shape.parameters_size()) {
+ return InvalidArgument(
+ "invalid number of arguments for computation: expected %d, got %zu",
+ program_shape.parameters_size(), argument_layouts.size());
+ }
+
+ if (options.stream()) {
+ if (!options.stream()->ok()) {
+ return InvalidArgument("stream is uninitialized or in an error state");
+ }
+
+ // Check stream matches service platform.
+ const se::Platform* stream_platform =
+ options.stream()->parent()->platform();
+ if (stream_platform != execute_backend_->platform()) {
+ return InvalidArgument(
+ "stream is for platform %s, but service targets platform %s",
+ stream_platform->Name().c_str(),
+ execute_backend_->platform()->Name().c_str());
+ }
+
+ // Cannot specify platform or device_ordinal with a stream. The stream
+ // determines these values.
+ if (options.device_ordinal() >= 0) {
+ return InvalidArgument(
+ "cannot set both device ordinal and stream options in "
+ "LocalExecuteOptions; the stream determines the device ordinal");
+ }
+ if (options.platform()) {
+ return InvalidArgument(
+ "cannot set both platform and stream options in "
+ "LocalExecuteOptions; the stream determines the platform");
+ }
+ }
+ if (options.platform() &&
+ options.platform() != execute_backend_->platform()) {
+ return InvalidArgument(
+ "service platform (%s) does not match platform set in "
+ "LocalExecuteOptions (%s)",
+ execute_backend_->platform()->Name().c_str(),
+ options.platform()->Name().c_str());
+ }
+
+ // TODO(cwhipkey): validate the thread pool provided?
+
+ if (!options.allocator()) {
+ return InvalidArgument("an allocator must be provided to ExecuteLocally");
+ }
+
+ if (options.allocator()->platform() != execute_backend_->platform()) {
+ return InvalidArgument(
+ "allocator platform (%s) does not match service platform (%s)",
+ options.allocator()->platform()->Name().c_str(),
+ execute_backend_->platform()->Name().c_str());
+ }
+
+ if (preallocated_result_buffer != nullptr) {
+ if (options.result_layout()) {
+ return InvalidArgument(
+ "cannot set both result ShapedBuffer and result layout; the result "
+ "ShapedBuffer determines the result layout");
+ }
+ if (!ShapeUtil::Compatible(preallocated_result_buffer->shape(),
+ program_shape.result())) {
+ return InvalidArgument(
+ "result ShapedBuffer of shape %s not compatible with computation "
+ "result shape %s",
+ ShapeUtil::HumanString(preallocated_result_buffer->shape()).c_str(),
+ ShapeUtil::HumanString(program_shape.result()).c_str());
+ }
+ }
+ if (options.result_layout()) {
+ TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*options.result_layout(),
+ program_shape.result()));
+ }
+
+ // Check that all argument layouts are valid and the right shape.
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& argument_shape = *argument_layouts[i];
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+ if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
+ return InvalidArgument(
+ "invalid argument shape for argument %d, expected %s, got %s", i,
+ ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(argument_shape).c_str());
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options,
+ ShapedBuffer* preallocated_result_buffer) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(computation));
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ // Determine device ordinal the computation will run on.
+ int device_ordinal;
+ if (options.device_ordinal() >= 0) {
+ device_ordinal = options.device_ordinal();
+ } else if (options.stream()) {
+ device_ordinal = options.stream()->parent()->device_ordinal();
+ } else {
+ device_ordinal = execute_backend_->default_device_ordinal();
+ }
+
+ // Check that all arguments are on the right platform and device ordinal.
+ std::vector<const Shape*> argument_layouts(arguments.size());
+ for (int i = 0; i < arguments.size(); ++i) {
+ auto argument = arguments[i];
+ if (argument->platform() != execute_backend_->platform() ||
+ argument->device_ordinal() != device_ordinal) {
+ return InvalidArgument(
+ "computation to run on device %s but argument %d is on "
+ "device %s:%d",
+ execute_backend_->device_name(device_ordinal).c_str(), i,
+ argument->platform()->Name().c_str(), argument->device_ordinal());
+ }
+ argument_layouts[i] = &argument->shape();
+ }
+
+ TF_RETURN_IF_ERROR(ValidateExecuteOptions(
+ *program_shape, argument_layouts, options, preallocated_result_buffer));
+
+ // Construct computation layout from the argument layouts.
+ auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
+ module_config->set_has_hybrid_result(true);
+ module_config->set_replica_count(execute_backend_->Replicas().size());
+ std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers;
+ auto* computation_layout = module_config->mutable_entry_computation_layout();
+ for (int i = 0; i < arguments.size(); ++i) {
+ const ShapedBuffer* argument = arguments[i];
+ if (ShapeUtil::IsTuple(argument->shape())) {
+ return Unimplemented("tuple arguments not supported yet");
+ }
+ argument_buffers.push_back(argument->buffer(/*index=*/{}));
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ argument->shape()));
+ }
+ if (options.result_layout()) {
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ *options.result_layout()));
+ } else if (preallocated_result_buffer != nullptr) {
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ preallocated_result_buffer->shape()));
+ } else {
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ }
+
+ ExecutableRunOptions run_options;
+ run_options.set_allocator(options.allocator());
+ run_options.set_inter_op_thread_pool(
+ execute_backend_->inter_op_thread_pool());
+ run_options.set_intra_op_thread_pool(
+ execute_backend_->eigen_intra_op_thread_pool_device());
+
+ // "acquired_stream" owns the stream used for execution if no stream is given.
+ std::unique_ptr<se::Stream> acquired_stream;
+ if (options.stream()) {
+ run_options.set_stream(options.stream());
+ } else {
+ se::StreamExecutor* stream_executor;
+ if (options.device_ordinal() >= 0) {
+ TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor(
+ options.device_ordinal()));
+ } else {
+ stream_executor = execute_backend_->default_stream_executor();
+ }
+ TF_ASSIGN_OR_RETURN(acquired_stream,
+ execute_backend_->AcquireStream(stream_executor));
+ run_options.set_stream(acquired_stream.get());
+ }
+ auto stream_releaser =
+ ::tensorflow::gtl::MakeCleanup([this, &acquired_stream]() {
+ if (acquired_stream != nullptr) {
+ execute_backend_->ReleaseStream(std::move(acquired_stream));
+ }
+ });
+
+ ExecutionProfile* profile = options.execution_profile();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<Executable> executable,
+ BuildAndCacheExecutable(versioned_handle, std::move(module_config),
+ argument_buffers, execute_backend_.get(),
+ run_options.stream()->parent(), profile));
+
+ if (preallocated_result_buffer == nullptr) {
+ return Service::ExecuteOnStreamWrapper<
+ StatusOr<std::unique_ptr<ShapedBuffer>>>(
+ executable.get(), &run_options, profile,
+ [&arguments](Executable* executable,
+ const ExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile) {
+ return executable->ExecuteOnStream(run_options, arguments,
+ hlo_execution_profile);
+ });
+ } else {
+ TF_RETURN_IF_ERROR(Service::ExecuteOnStreamWrapper<tensorflow::Status>(
+ executable.get(), &run_options, profile,
+ [&arguments, preallocated_result_buffer](
+ Executable* executable, const ExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile) {
+ return executable->ExecuteOnStream(run_options, arguments,
+ preallocated_result_buffer,
+ hlo_execution_profile);
+ }));
+    // To satisfy the return value type, return a null ShapedBuffer pointer.
+ return std::unique_ptr<ShapedBuffer>();
+ }
+}
+
+StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
+ const ComputationHandle& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const Shape* result_layout, int device_ordinal, bool has_hybrid_result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(computation));
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ // Validate incoming layouts.
+ if (argument_layouts.size() != program_shape->parameters_size()) {
+ return InvalidArgument(
+ "invalid number of arguments for computation: expected %d, got %zu",
+ program_shape->parameters_size(), argument_layouts.size());
+ }
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& argument_shape = *argument_layouts[i];
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
+ if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) {
+ return InvalidArgument(
+ "invalid argument shape for argument %d, expected %s, got %s", i,
+ ShapeUtil::HumanString(program_shape->parameters(i)).c_str(),
+ ShapeUtil::HumanString(argument_shape).c_str());
+ }
+ }
+ if (result_layout != nullptr) {
+ TF_RETURN_IF_ERROR(
+ ValidateResultShapeWithLayout(*result_layout, program_shape->result()));
+ }
+
+ // Construct computation layout from the argument layouts.
+ auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
+ module_config->set_has_hybrid_result(has_hybrid_result);
+ module_config->set_replica_count(execute_backend_->Replicas().size());
+ auto* computation_layout = module_config->mutable_entry_computation_layout();
+ for (int i = 0; i < argument_layouts.size(); ++i) {
+ const Shape& shape = *argument_layouts[i];
+ if (ShapeUtil::IsTuple(shape)) {
+ return Unimplemented("tuple arguments not supported yet");
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ shape));
+ }
+ if (result_layout != nullptr) {
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ *result_layout));
+ } else {
+ computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ }
+
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+ execute_backend_->stream_executor(device_ordinal));
+
+ std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
+ argument_layouts.size());
+ return BuildExecutable(versioned_handle, std::move(module_config),
+ /*executable_for_compute_constant=*/false,
+ argument_buffers, execute_backend_.get(), executor);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
new file mode 100644
index 0000000000..3e160a0201
--- /dev/null
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -0,0 +1,185 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Computation execution options which may be set by the client when executing
+// locally (via LocalClient::ExecuteLocally).
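+//
+// A minimal usage sketch (illustrative only; the allocator and
+// shape_with_layout variables below are hypothetical):
+//
+//   LocalExecuteOptions options;
+//   options.set_allocator(&allocator)
+//       .set_device_ordinal(0)
+//       .set_result_layout(shape_with_layout);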
+class LocalExecuteOptions {
+ public:
+ // Specifies the allocator to use during execution. Execution will fail if no
+ // allocator is provided.
+ LocalExecuteOptions& set_allocator(DeviceMemoryAllocator* allocator);
+ DeviceMemoryAllocator* allocator() const;
+
+ // If set, this is the platform to run the computation on. This must match
+ // the underlying platform of the service. A value of nullptr means the
+ // platform is not set.
+ // TODO(b/28616830): Support multiple platforms.
+ LocalExecuteOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // If set, this is the device to run the computation on. Valid device_ordinal
+ // values are: 0 to # of devices - 1. These values are identical to the
+ // device ordinal values used by StreamExecutor. A value of < 0 means the
+ // ordinal is not set.
+ LocalExecuteOptions& set_device_ordinal(int device_ordinal);
+ int device_ordinal() const;
+
+ // If set, this is the stream to run the computation on. The platform of the
+ // stream must match the service's platform. The device ordinal
+ // option (if set) must match the stream's device. A value of nullptr means
+ // the stream is not set.
+ LocalExecuteOptions& set_stream(perftools::gputools::Stream* stream);
+ perftools::gputools::Stream* stream() const;
+
+ // If set, collect profile information during execution and fill the given
+ // ExecutionProfile object with the profile data. A value of nullptr means
+ // the profile is not set.
+ LocalExecuteOptions& set_execution_profile(ExecutionProfile* profile);
+ ExecutionProfile* execution_profile() const;
+
+  // If set, this specifies the layout of the result of the computation. If not
+  // set, the service will choose the layout of the result. A Shape is used to
+  // store the layout to accommodate tuple result shapes. A value of nullptr
+  // means the shape is not set.
+ LocalExecuteOptions& set_result_layout(const Shape& shape_with_layout);
+ const Shape* result_layout() const;
+
+ private:
+ DeviceMemoryAllocator* allocator_ = nullptr;
+ perftools::gputools::Platform* platform_ = nullptr;
+ int device_ordinal_ = -1;
+ perftools::gputools::Stream* stream_ = nullptr;
+ ExecutionProfile* profile_ = nullptr;
+
+ bool has_result_shape_with_layout_ = false;
+ Shape result_shape_with_layout_;
+};
+
+// Service implementation that extends the XLA Service to leverage running
+// in the same process as the client.
+class LocalService : public Service {
+ public:
+ // Factory for creating a LocalService. The parameter platform is the platform
+ // that the service should target. If platform is null then the default
+ // platform is used.
+ static StatusOr<std::unique_ptr<LocalService>> NewService(
+ perftools::gputools::Platform* platform);
+ static StatusOr<std::unique_ptr<LocalService>> NewService(
+ const ServiceOptions& options);
+
+ // For an array of arguments, validate that each is placed on the
+ // specified device_ordinal, and return the DeviceMemoryBase
+ // corresponding to each argument.
+ tensorflow::Status ResolveArguments(
+ const tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ int device_ordinal,
+ std::vector<perftools::gputools::DeviceMemoryBase>* argument_ptrs);
+
+ // Return a handle to a buffer large enough to hold shape, allocated
+ // on device_ordinal. If allocate_space_for_deep_copy, the buffer is
+ // large enough to hold all sub-buffers of a tuple shape, otherwise
+ // it is only as large as the top-level tuple pointer array.
+ StatusOr<GlobalDataHandle> AllocateBufferOnDevice(
+ const Shape& shape, int device_ordinal,
+ bool allocate_space_for_deep_copy);
+
+ // Execute the given computation with the given arguments and options with
+ // zero-copy data handling of arguments and result.
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteLocally(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options);
+
+ // Overload which writes the result into the given ShapedBuffer "result".
+  // Due to aliasing, some buffers which comprise "result" may not be used by
+  // the computation and may therefore be left uninitialized. The
+  // |ShapedBuffer::buffer| or |ShapedBuffer::mutable_buffer| methods should be
+  // used to map an index to the initialized buffer.
+ //
+ // For example:
+ // Let 'result' be a ShapedBuffer holding a tuple with the same element,
+ // 'x', twice: (x, x). It is incorrect to assume that the second buffer
+ // which comprises 'result' is initialized. Instead, a mapping has been
+ // added to 'result' which can be used to recover the correct buffer.
+ // In this case, result->buffer({0}) should be used to extract the address of
+ // the first tuple element while result->buffer({1}) should be used for the
+ // second.
+ tensorflow::Status ExecuteLocally(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result_buffer);
+
+ // Compiles the computation for ahead-of-time execution. This is intended for
+ // use in static compilation. See |LocalClient::CompileAheadOfTime| for
+ // additional details.
+ StatusOr<std::unique_ptr<AotCompilationResult>> CompileAheadOfTime(
+ const ComputationHandle& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+      const Shape& result_layout, const AotCompilationOptions& options);
+
+ // Builds an Executable with the given argument layouts and options. If
+ // result_layout is non-null, then the executable is compiled to produce a
+ // result of the given layout.
+ StatusOr<std::unique_ptr<Executable>> CompileExecutable(
+ const ComputationHandle& computation,
+ const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const Shape* result_layout, int device_ordinal, bool has_hybrid_result);
+
+ private:
+ explicit LocalService(std::unique_ptr<Backend> backend,
+ std::unique_ptr<Backend> compute_constant_backend);
+ LocalService(const LocalService&) = delete;
+ void operator=(const LocalService&) = delete;
+
+  // Internal helper for executing a computation. If preallocated_result_buffer
+  // is null then the result is returned as a ShapedBuffer. If it is non-null
+  // then the result is written into preallocated_result_buffer and a null
+  // ShapedBuffer pointer is returned.
+ StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteLocallyInternal(
+ const ComputationHandle& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options,
+ ShapedBuffer* preallocated_result_buffer);
+
+ // Validates the given options and argument layouts and returns an appropriate
+ // error code.
+ tensorflow::Status ValidateExecuteOptions(
+ const ProgramShape& program_shape,
+      tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
+ const LocalExecuteOptions& options,
+ const ShapedBuffer* preallocated_result_buffer);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOCAL_SERVICE_H_
diff --git a/tensorflow/compiler/xla/service/logical_buffer.cc b/tensorflow/compiler/xla/service/logical_buffer.cc
new file mode 100644
index 0000000000..00e4b35d15
--- /dev/null
+++ b/tensorflow/compiler/xla/service/logical_buffer.cc
@@ -0,0 +1,39 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+
+#include <ostream>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace xla {
+
+string LogicalBuffer::ToString() const {
+ return tensorflow::strings::StrCat(instruction_->name(), "[",
+ tensorflow::str_util::Join(index_, ","),
+ "](#", id_, ")");
+}
+
+std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer) {
+ out << buffer.ToString();
+ return out;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h
new file mode 100644
index 0000000000..21af9dcf66
--- /dev/null
+++ b/tensorflow/compiler/xla/service/logical_buffer.h
@@ -0,0 +1,153 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_
+
+#include <iosfwd>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+struct HashLogicalBuffer;
+
+// Class describing a contiguous sequence of elements (ie, C array) which form
+// the components of Shaped values in XLA. XLA arrays are trivially a
+// single LogicalBuffer. Tuple values are made up of more than one
+// LogicalBuffer: a LogicalBuffer for the pointers to elements, and a
+// LogicalBuffer for each child element.
+//
+// Every buffer is defined by a particular instruction and most instructions
+// define only a single buffer. Instructions which define a single buffer
+// include array-shaped instructions such as Add but also Tuple-shaped
+// instructions such as Tuple. The Tuple instruction defines a single buffer
+// which is a vector of pointers to the buffers containing the Tuple
+// instruction's operands. Though the result of the Tuple instruction includes
+// multiple buffers, only the top-level buffer (the vector of pointers) is
+// defined by the Tuple instruction. The buffers containing the tuple elements
+// are defined by earlier instructions, usually the operands of the Tuple
+// instruction.
+//
+// Instructions which construct both the tuple *and* the tuple elements define
+// more than one buffer. This includes (at least) tuple-shaped Constant,
+// Parameter, Infeed and While instructions. The tuple-shaped instructions do
+// not assemble a tuple from existing buffers like the Tuple instruction does,
+// but rather define the entire tuple.
+//
+// Some instructions, such as Bitcast, define no buffers. These instructions
+// simply forward buffers from their operands.
+//
+// The LogicalBuffer object describes which HLO instruction defines a buffer and
+// where within that instruction's output shape the buffer is defined. The
+// location within the output shape is indicated by LogicalBuffer::index() which
+// is defined identically to the index used in
+// ShapeUtil::GetSubshape(). Examples:
+//
+// %add = Add(%foo, %bar)
+// %tuple_constant = Constant({1, {42, 43}})
+//
+// %add defines a single array-shaped buffer LogicalBuffer(%add, {}) which holds
+// the array result of the add operation. The nested-tuple-shaped
+// %tuple_constant defines 5 buffers described by the following LogicalBuffer
+// objects:
+//
+// LogicalBuffer(%tuple_constant, {}) // "Top-level" buffer: vector of
+// // pointers to LogicalBuffers at
+// // indices {0} and {1}
+// LogicalBuffer(%tuple_constant, {0}) // Holds value "1"
+// LogicalBuffer(%tuple_constant, {1}) // Holds nested tuple: vector of
+// // pointers to LogicalBuffers at
+// // indices {1, 0} and {1, 1}
+// LogicalBuffer(%tuple_constant, {1, 0}) // Holds value "42"
+// LogicalBuffer(%tuple_constant, {1, 1}) // Holds value "43"
+class LogicalBuffer {
+ public:
+ // Id is a unique identifier for the LogicalBuffer to facilitate efficient
+ // collections of LogicalBuffers with stable iteration order.
+ // LogicalBuffers are typically created and accessed through
+ // TuplePointsToAnalysis, and points-to analysis assigns each LogicalBuffer a
+ // unique value.
+ using Id = int64;
+
+ LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id)
+ : instruction_(instruction), index_(index), id_(id) {}
+
+ Id id() const { return id_; }
+
+ // Return the instruction that defines the buffer.
+ HloInstruction* instruction() const { return instruction_; }
+
+  // Return the index within the output of the instruction where the buffer is
+  // defined. The index is defined as in ShapeUtil::GetSubshape().
+ const ShapeIndex& index() const { return index_; }
+
+ // Return the shape of the buffer. This reference points into the shape field
+ // of the instruction defining the buffer. Therefore, the returned shape will
+  // contain the layout of the instruction, if any.
+ const Shape& shape() const {
+ return ShapeUtil::GetSubshape(instruction_->shape(), index_);
+ }
+
+ // Returns true if this buffer is the top-level output buffer of the defining
+ // HLO instruction. This is equivalent to index == {}.
+ bool IsTopLevel() const { return index_.empty(); }
+
+ // Whether this buffer contains a tuple.
+ bool IsTuple() const { return ShapeUtil::IsTuple(shape()); }
+
+ // operator< is required for std::set.
+ bool operator<(const LogicalBuffer& other) const { return id_ < other.id_; }
+
+ // Whether this buffer contains an array.
+ bool IsArray() const { return ShapeUtil::IsArray(shape()); }
+
+ string ToString() const;
+
+ private:
+ friend struct HashLogicalBuffer;
+ HloInstruction* instruction_;
+ ShapeIndex index_;
+ Id id_;
+
+  // As with other HLO constructs (HloInstruction, etc.), pointers are used for
+  // equality comparison, so disable all copying.
+ TF_DISALLOW_COPY_AND_ASSIGN(LogicalBuffer);
+};
+
+struct HashLogicalBuffer {
+ size_t operator()(const LogicalBuffer& b) const {
+ std::hash<const HloInstruction*> hasher;
+ size_t h = hasher(b.instruction_);
+ for (int i = 0; i < b.index_.size(); i++) {
+ h += static_cast<size_t>(b.index_[i] << i);
+ }
+ return h;
+ }
+};
+
+std::ostream& operator<<(std::ostream& out, const LogicalBuffer& buffer);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_H_
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
new file mode 100644
index 0000000000..4014856b9b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/name_uniquer.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
+ string root = prefix.empty() ? "name" : prefix.ToString();
+ int* count = &(generated_names_[root]);
+ if (*count == 0) {
+ *count = 1;
+ return root;
+ } else {
+ tensorflow::strings::StrAppend(&root, separator_, *count);
+ (*count)++;
+ return root;
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
new file mode 100644
index 0000000000..b0944adbc1
--- /dev/null
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Simple stateful class that helps generate "unique" names. To use it, simply
+// call GetUniqueName as many times as needed. The names returned by
+// GetUniqueName are guaranteed to be distinct for this instance of the class.
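+//
+// For example, with the default "__" separator (illustrative):
+//
+//   NameUniquer uniquer;
+//   uniquer.GetUniqueName("add");  // returns "add"
+//   uniquer.GetUniqueName("add");  // returns "add__1"
+//   uniquer.GetUniqueName("add");  // returns "add__2"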
+class NameUniquer {
+ public:
+ explicit NameUniquer(const string& separator = "__")
+ : separator_(separator) {}
+
+  // Returns a unique name as a string, using the optional prefix as the base
+  // of the name.
+ string GetUniqueName(tensorflow::StringPiece prefix = "");
+
+ private:
+ // The string to use to separate the prefix of the name from the uniquing
+ // integer value.
+ string separator_;
+
+ // Map from name prefix to the number of names generated using that prefix
+ // so far.
+ std::unordered_map<string, int> generated_names_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_NAME_UNIQUER_H_
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
new file mode 100644
index 0000000000..116bd3f067
--- /dev/null
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -0,0 +1,166 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/platform_util.h"
+
+#include <algorithm>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+// Minimum supported CUDA compute capability is 3.5.
+constexpr int kMinCudaComputeCapabilityMajor = 3;
+constexpr int kMinCudaComputeCapabilityMinor = 5;
+
+/* static */ StatusOr<std::vector<se::Platform*>>
+PlatformUtil::GetSupportedPlatforms() {
+ se::MultiPlatformManager::PlatformMap platform_map;
+ se::port::Status platforms_status = se::MultiPlatformManager::WithPlatforms(
+ [&platform_map](se::MultiPlatformManager::PlatformMap* map) {
+ platform_map = *map;
+ return se::port::Status::OK();
+ });
+ if (platform_map.empty()) {
+ LOG(WARNING) << "no executor platforms available: platform map is empty";
+ }
+
+ // Gather all platforms which have an XLA compiler.
+ std::vector<se::Platform*> platforms;
+ for (auto& platform_pair : platform_map) {
+ auto* platform = platform_pair.second;
+ auto compiler_status = Compiler::GetForPlatform(platform);
+ if (compiler_status.ok()) {
+ if (platform->VisibleDeviceCount() > 0) {
+ LOG(INFO) << "platform " << platform->Name() << " present with "
+ << platform->VisibleDeviceCount() << " visible devices";
+ } else {
+ LOG(WARNING) << "platform " << platform->Name() << " present but no "
+ << "visible devices found";
+ }
+ // Note: currently we call zero device platforms "supported" on the basis
+ // that, if the platform support was linked in, it was probably intended
+ // to be used for execution, and this way we can flag an error.
+ //
+ // TODO(b/33730287) If we want an alternative version of this behavior we
+ // could add an --xla_fallback_to_host flag.
+ platforms.push_back(platform);
+ } else {
+ LOG(INFO) << "platform " << platform->Name() << " present but no "
+ << "XLA compiler available: "
+ << compiler_status.status().error_message();
+ }
+ }
+ return platforms;
+}
+
+/* static */ StatusOr<se::Platform*> PlatformUtil::GetDefaultPlatform() {
+ TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms());
+ if (platforms.empty()) {
+ return NotFound("no platforms found");
+ } else if (platforms.size() == 1) {
+ return platforms[0];
+ } else if (platforms.size() == 2) {
+ // In the service we always link the cpu backend for ComputeConstant. So if
+ // one of the two platforms is CPU then pick the other (non-cpu) platform as
+ // the default.
+ if (platforms[0]->id() == se::host::kHostPlatformId) {
+ return platforms[1];
+ } else if (platforms[1]->id() == se::host::kHostPlatformId) {
+ return platforms[0];
+ }
+ }
+
+ // Multiple platforms present and we can't pick a reasonable default.
+ auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); };
+ string platforms_string = tensorflow::str_util::Join(platforms, ", ", l);
+ return InvalidArgument(
+ "must specify platform because more than one platform found: %s",
+ platforms_string.c_str());
+}
+
+// Returns whether the device underlying the given StreamExecutor is supported
+// by XLA.
+static bool IsDeviceSupported(se::StreamExecutor* executor) {
+ const auto& description = executor->GetDeviceDescription();
+ if (executor->platform()->id() == se::cuda::kCudaPlatformId) {
+ // CUDA devices must have a minimum compute capability.
+ int major_version, minor_version;
+ if (description.cuda_compute_capability(&major_version, &minor_version)) {
+ if (major_version < kMinCudaComputeCapabilityMajor ||
+ (major_version == kMinCudaComputeCapabilityMajor &&
+ minor_version < kMinCudaComputeCapabilityMinor)) {
+ LOG(INFO) << "StreamExecutor cuda device ("
+ << executor->device_ordinal() << ") is of "
+ << "insufficient compute capability: "
+ << kMinCudaComputeCapabilityMajor << "."
+ << kMinCudaComputeCapabilityMinor << " required, "
+ << "device is " << major_version << "." << minor_version;
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+/* static */ StatusOr<std::vector<se::StreamExecutor*>>
+PlatformUtil::GetStreamExecutors(se::Platform* platform) {
+ int device_count = platform->VisibleDeviceCount();
+ if (device_count <= 0) {
+ return NotFound("no %s devices found", platform->Name().c_str());
+ }
+ if (platform->id() == se::host::kHostPlatformId) {
+ // On host "devices", StreamExecutor exports a device for each hardware
+ // thread. Because we parallelize a single computation across threads, it
+ // doesn't make sense to expose these as separate devices, so fix the number
+ // of devices to one.
+ device_count = 1;
+ }
+ std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
+ for (int i = 0; i < device_count; ++i) {
+ se::StreamExecutorConfig config;
+ config.ordinal = i;
+ auto executor_status = platform->GetExecutor(config);
+ if (executor_status.ok()) {
+ se::StreamExecutor* executor = executor_status.ValueOrDie();
+ if (IsDeviceSupported(executor)) {
+ stream_executors[i] = executor;
+ }
+ } else {
+ LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name()
+ << ":" << i << ": "
+ << executor_status.status().error_message();
+ }
+ }
+ if (std::all_of(stream_executors.begin(), stream_executors.end(),
+ [](se::StreamExecutor* s) { return s == nullptr; })) {
+ return InternalError("no supported devices found for platform %s",
+ platform->Name().c_str());
+ }
+ return stream_executors;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h
new file mode 100644
index 0000000000..fe0281a69a
--- /dev/null
+++ b/tensorflow/compiler/xla/service/platform_util.h
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Utilities for querying platforms and devices used by XLA.
+class PlatformUtil {
+ public:
+ // Returns the platforms present on the system and supported by XLA.
+ //
+  // Note that a platform present with zero devices is still returned in this
+  // sequence if we *do* have compilation support for it.
+ static StatusOr<std::vector<perftools::gputools::Platform*>>
+ GetSupportedPlatforms();
+
+ // Convenience function which returns the default supported platform. If
+ // exactly one supported platform is present, then this platform is the
+ // default platform. If exactly two supported platforms are present and one
+ // platform is CPU (host) then the non-CPU platform is default. This logic is
+ // used because the XLA service always links in the CPU backend to run
+ // ComputeConstant, so if exactly one other platform is linked in, we assume
+ // the intent is to execute on that non-CPU platform. If none of these
+ // conditions are met the function returns an error.
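+  //
+  // A typical call (illustrative sketch, mirroring its use in the service):
+  //   TF_ASSIGN_OR_RETURN(perftools::gputools::Platform* platform,
+  //                       PlatformUtil::GetDefaultPlatform());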
+ static StatusOr<perftools::gputools::Platform*> GetDefaultPlatform();
+
+ // Returns a vector of StreamExecutors for the given platform. The vector is
+ // indexed by device ordinal (device numbering used by StreamExecutor). If an
+  // element is nullptr, then the device is present but not supported by XLA.
+ //
+ // If the platform has no visible devices, a not-found error is returned.
+ static StatusOr<std::vector<perftools::gputools::StreamExecutor*>>
+ GetStreamExecutors(perftools::gputools::Platform* platform);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(PlatformUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_PLATFORM_UTIL_H_
diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc
new file mode 100644
index 0000000000..5625804c2e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reshape_mover.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include <algorithm>
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace xla {
+
+namespace {
+
+// Returns whether `a` and `b` are equivalent for the purposes of this pass.
+bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) {
+ if (a->opcode() != b->opcode() ||
+ !ShapeUtil::SameDimensions(a->shape(), b->shape())) {
+ return false;
+ }
+ switch (a->opcode()) {
+ case HloOpcode::kTranspose:
+ return a->dimensions() == b->dimensions();
+ case HloOpcode::kReshape:
+ return ShapeUtil::SameDimensions(a->operand(0)->shape(),
+ b->operand(0)->shape());
+ default:
+ return false;
+ }
+}
+
+bool IsElementwiseOfEquivalentReshapesOrTransposes(
+ const HloInstruction* instruction) {
+ const std::vector<HloInstruction*>& operands = instruction->operands();
+ return instruction->IsElementwise() && instruction->operand_count() > 0 &&
+ std::all_of(operands.begin(), operands.end(),
+ [](const HloInstruction* instruction) {
+                       // We require that the operand have no other users, as
+                       // otherwise this is not a clear win.
+ return 1 == instruction->users().size();
+ }) &&
+ // Check whether each operand beyond the first is equivalent to the
+ // first.
+ std::all_of(operands.begin(), operands.end(),
+ [&operands](const HloInstruction* operand) {
+ return AreEquivalentReshapes(operands[0], operand);
+ });
+}
+
+// Try to sink any reshape or transpose operands of `instruction` across it. We
+// do so if `instruction` is elementwise and all operands are equivalent
+// reshapes or transposes.
+bool TrySinkReshapeOrTranspose(HloComputation* computation,
+ HloInstruction* instruction) {
+ if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
+ std::vector<HloInstruction*> operands = instruction->operands();
+ auto old_reshape = operands[0];
+ for (size_t i = 0; i < operands.size(); ++i) {
+ operands[i] = operands[i]->mutable_operand(0);
+ }
+ auto new_elementwise =
+ computation->AddInstruction(instruction->CloneWithNewOperands(
+ // `instruction` may change the element type, e.g., from
+ // operands[0] -> reshape -> convert (`instruction`)
+ // to
+ // operands[0] -> convert' -> reshape'
+ //
+ // In this case, convert' should have the same element type as
+ // `convert` and the same dimensions as operands[0].
+ ShapeUtil::MakeShape(
+ instruction->shape().element_type(),
+ AsInt64Slice(operands[0]->shape().dimensions())),
+ operands));
+ std::unique_ptr<HloInstruction> new_reshape;
+ switch (old_reshape->opcode()) {
+ case HloOpcode::kReshape:
+ new_reshape = HloInstruction::CreateReshape(instruction->shape(),
+ new_elementwise);
+ break;
+ case HloOpcode::kTranspose:
+ new_reshape = HloInstruction::CreateTranspose(
+ instruction->shape(), new_elementwise, old_reshape->dimensions());
+ break;
+ default:
+ LOG(FATAL) << "Bad opcode";
+ }
+ computation->ReplaceWithNewInstruction(instruction, std::move(new_reshape));
+ return true;
+ }
+ return false;
+}
+
+} // namespace
+
+StatusOr<bool> ReshapeMover::Run(HloModule* module) {
+ return std::any_of(
+ module->computations().begin(), module->computations().end(),
+ [](const std::unique_ptr<HloComputation>& computation) {
+ std::list<HloInstruction*> postorder =
+ computation->MakeInstructionPostOrder();
+ return std::any_of(postorder.begin(), postorder.end(),
+ [&computation](HloInstruction* instruction) {
+ return TrySinkReshapeOrTranspose(computation.get(),
+ instruction);
+ });
+ });
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h
new file mode 100644
index 0000000000..f7146b0ee3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reshape_mover.h
@@ -0,0 +1,36 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
+
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// A pass which moves Reshapes and Transposes to let later passes combine them.
+// This currently only moves them outputward across elementwise ops whose
+// operands are all equivalent Reshapes or Transposes, but in the future it
+// could potentially move them inputward as well.
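+//
+// For example (an illustrative sketch of the intended rewrite):
+//
+//   add(reshape(A), reshape(B))  ==>  reshape(add(A, B))
+//
+// where both reshapes produce the same dimensions and have no other users.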
+class ReshapeMover : public HloPass {
+ public:
+ ReshapeMover() : HloPass("reshape motion") {}
+
+ StatusOr<bool> Run(HloModule* module) override;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_RESHAPE_MOVER_H_
diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc
new file mode 100644
index 0000000000..850295c726
--- /dev/null
+++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/reshape_mover.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace xla {
+namespace {
+using ReshapeMoverTest = HloTestBase;
+
+TEST_F(ReshapeMoverTest, ReshapesWithNonSameInputShapesNotMoved) {
+ auto root_shape = ShapeUtil::MakeShape(F32, {8, 7});
+ HloComputation::Builder builder(TestName());
+ auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0"));
+ auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
+      1, ShapeUtil::MakeShape(F32, {1, 8, 7, 1}), "param1"));
+ auto reshape2 =
+ builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0));
+ auto reshape3 =
+ builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param1));
+ auto add4 = builder.AddInstruction(HloInstruction::CreateBinary(
+ root_shape, HloOpcode::kAdd, reshape2, reshape3));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+ EXPECT_EQ(add4, computation->root_instruction());
+ EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
+ EXPECT_EQ(add4, computation->root_instruction());
+}
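+
+// Sketch of a complementary positive case (not part of the original file):
+// if param0 and param1 shared the same shape, e.g. {1, 8, 1, 7}, the pass
+// would be expected to sink the equivalent reshapes below the add, so Run()
+// would return true and the new root would be a reshape of an add.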
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
new file mode 100644
index 0000000000..847aea7888
--- /dev/null
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -0,0 +1,1428 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/service.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+namespace {
+
+// Copies the contents of an Allocation into a Literal proto.
+tensorflow::Status LiteralFromAllocation(const Allocation* allocation,
+ const Shape& literal_shape,
+ Literal* literal) {
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ allocation->backend()->stream_executor(allocation->device_ordinal()));
+ return allocation->backend()->transfer_manager()->TransferLiteralFromDevice(
+ executor, allocation->device_memory(), allocation->shape(), literal_shape,
+ literal);
+}
+
+// Records the arguments used to invoke a computation in a SessionModule
+// proto.
+tensorflow::Status RecordArguments(
+ const tensorflow::gtl::ArraySlice<const Allocation*> arg_allocations,
+ SessionModule* module) {
+ module->clear_arguments();
+ for (const Allocation* allocation : arg_allocations) {
+ TF_RETURN_IF_ERROR(LiteralFromAllocation(allocation, allocation->shape(),
+ module->add_arguments()));
+ }
+ return tensorflow::Status::OK();
+}
+
+// Records the result of a computation in a SessionModule proto.
+tensorflow::Status RecordResult(const Allocation* result_allocation,
+ SessionModule* module) {
+ module->clear_result();
+ return LiteralFromAllocation(result_allocation, result_allocation->shape(),
+ module->mutable_result());
+}
+
+} // namespace
+
+ServiceOptions& ServiceOptions::set_platform(
+ perftools::gputools::Platform* platform) {
+ platform_ = platform;
+ return *this;
+}
+
+perftools::gputools::Platform* ServiceOptions::platform() const {
+ return platform_;
+}
+
+ServiceOptions& ServiceOptions::set_number_of_replicas(int number_of_replicas) {
+ number_of_replicas_ = number_of_replicas;
+ return *this;
+}
+
+int ServiceOptions::number_of_replicas() const { return number_of_replicas_; }
+
+/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
+ perftools::gputools::Platform* platform) {
+ ServiceOptions default_options;
+ default_options.set_platform(platform);
+ return NewService(default_options);
+}
+
+/* static */ StatusOr<std::unique_ptr<Service>> Service::NewService(
+ const ServiceOptions& options) {
+ perftools::gputools::Platform* platform = options.platform();
+ std::unique_ptr<Backend> execute_backend;
+ if (platform == nullptr) {
+ TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
+ }
+ TF_ASSIGN_OR_RETURN(
+ execute_backend,
+ Backend::CreateBackend(platform, options.number_of_replicas()));
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
+ CreateComputeConstantBackend());
+ std::unique_ptr<Service> service(new Service(
+ std::move(execute_backend), std::move(compute_constant_backend)));
+ return std::move(service);
+}
+
+/* static */ StatusOr<std::unique_ptr<Backend>>
+Service::CreateComputeConstantBackend() {
+ TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
+ PlatformUtil::GetSupportedPlatforms());
+ for (auto* platform : platforms) {
+ if (platform->id() == se::host::kHostPlatformId) {
+ return Backend::CreateBackend(platform, /*replica_count=*/1);
+ }
+ }
+ return NotFound("CPU platform not found");
+}
+
+/* static */ void Service::DumpExecutedHlo(const HloModule& module,
+ const string& label,
+ const HloExecutionProfile* profile) {
+ VLOG(2) << "module name = " << module.name();
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ if (!flags->xla_generate_hlo_graph.empty() &&
+ RE2::PartialMatch(module.name(), flags->xla_generate_hlo_graph)) {
+ hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
+ flags->xla_hlo_graph_addresses,
+ flags->xla_hlo_graph_layout, profile);
+ }
+ if (!flags->xla_log_hlo_text.empty() &&
+ RE2::PartialMatch(module.name(), flags->xla_log_hlo_text)) {
+ LOG(INFO) << "HLO for module " << module.name();
+ LOG(INFO) << "Label: " << label;
+ XLA_LOG_LINES(2, module.ToString());
+ }
+ if (!flags->xla_dump_hlo_text_to.empty()) {
+ hlo_graph_dumper::DumpText(module, label, flags->xla_dump_hlo_text_to);
+ }
+}
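+
+// For illustration (not part of the original source): setting the
+// xla_generate_hlo_graph flag to ".*" makes every module name match the
+// regular expression, so a graph is dumped for every executed module.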
+
+/* static */ Compiler::HloDumper Service::MakeHloDumper() {
+ return [](const HloModule& module, const string& label) {
+ return DumpExecutedHlo(module, label, /*profile=*/nullptr);
+ };
+}
+
+Service::Service(std::unique_ptr<Backend> execute_backend,
+ std::unique_ptr<Backend> compute_constant_backend)
+ : execute_backend_(std::move(execute_backend)),
+ compute_constant_backend_(std::move(compute_constant_backend)) {
+ LOG(INFO) << "XLA service executing computations on platform "
+ << execute_backend_->platform()->Name() << ". Devices:";
+ for (int i = 0; i < execute_backend_->device_count(); ++i) {
+ if (execute_backend_->device_ordinal_supported(i)) {
+ se::StreamExecutor* executor =
+ execute_backend_->stream_executor(i).ValueOrDie();
+ const auto& description = executor->GetDeviceDescription();
+ LOG(INFO) << tensorflow::strings::Printf(
+ " StreamExecutor device (%d): %s, %s", i, description.name().c_str(),
+ description.platform_version().c_str());
+ } else {
+ LOG(INFO) << tensorflow::strings::Printf(
+ " StreamExecutor device (%d) not supported", i);
+ }
+ }
+}
+
+tensorflow::Status Service::Computation(const ComputationRequest* arg,
+ ComputationResponse* result) {
+ if (arg->name().empty()) {
+ return InvalidArgument("computation request needs a name");
+ }
+
+ *result->mutable_computation() =
+ computation_tracker_.NewComputation(arg->name());
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::CreateChannelHandle(
+ const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) {
+ *result->mutable_channel() = channel_tracker_.NewChannel();
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) {
+ return allocation_tracker_.Unregister(arg->data());
+}
+
+// Deconstructs a previously-allocated global handle.
+tensorflow::Status Service::DeconstructTuple(const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) {
+ TF_ASSIGN_OR_RETURN(
+ std::vector<GlobalDataHandle> elements,
+ allocation_tracker_.DeconstructTuple(arg->tuple_handle()));
+
+ for (auto& element : elements) {
+ *result->add_element_handles() = element;
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ValidateResultShapeWithLayout(
+ const Shape& shape_with_layout, const Shape& result_shape) const {
+ if (!ShapeUtil::Compatible(shape_with_layout, result_shape)) {
+ return InvalidArgument(
+ "Shape used to set computation result layout %s is not compatible "
+ "with result shape %s",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str(),
+ ShapeUtil::HumanString(result_shape).c_str());
+ }
+ if (!LayoutUtil::HasLayout(shape_with_layout)) {
+ return InvalidArgument(
+ "Shape used to set computation result layout %s does not have layout",
+ ShapeUtil::HumanStringWithLayout(shape_with_layout).c_str());
+ }
+ return ShapeUtil::ValidateShape(shape_with_layout);
+}
+
+StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
+ tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ const Backend* backend, int device_ordinal) {
+ std::vector<const Allocation*> allocations;
+ for (int i = 0; i < arguments.size(); ++i) {
+ auto allocation_status = allocation_tracker_.Resolve(*arguments[i]);
+ if (!allocation_status.ok()) {
+ return Status(allocation_status.status().code(),
+ tensorflow::strings::StrCat(
+ allocation_status.status().error_message(), ", ",
+ "failed to resolve allocation for parameter ", i));
+ }
+ const Allocation* allocation = allocation_status.ValueOrDie();
+
+ // Verify allocation is same platform and device as the execution.
+ if (allocation->backend() != backend ||
+ allocation->device_ordinal() != device_ordinal) {
+ return InvalidArgument(
+ "argument %d is on device %s but computation will be executed "
+ "on device %s",
+ i,
+ allocation->backend()
+ ->device_name(allocation->device_ordinal())
+ .c_str(),
+ backend->device_name(device_ordinal).c_str());
+ }
+
+ allocations.push_back(allocation);
+ }
+ return allocations;
+}
+
+StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
+ const ProgramShape& program_shape,
+ tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+ const Shape* shape_with_output_layout, uint64 seed) {
+ auto module_config = MakeUnique<HloModuleConfig>(program_shape);
+ auto* computation_layout = module_config->mutable_entry_computation_layout();
+
+ if (program_shape.parameters_size() != arguments.size()) {
+ return InvalidArgument("computation takes %d parameters, but %zu given",
+ program_shape.parameters_size(), arguments.size());
+ }
+
+ for (int i = 0; i < arguments.size(); ++i) {
+    // Verify that the shape of each argument matches the corresponding
+    // parameter shape in the ProgramShape.
+ if (!ShapeUtil::Compatible(arguments[i]->shape(),
+ program_shape.parameters(i))) {
+ return InvalidArgument(
+ "computation expects parameter %d to have shape %s, given shape %s",
+ i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
+ ShapeUtil::HumanString(arguments[i]->shape()).c_str());
+ }
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
+ arguments[i]->shape()));
+ }
+ if (shape_with_output_layout == nullptr) {
+ computation_layout->mutable_result_layout()->Clear();
+ } else {
+ TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*shape_with_output_layout,
+ program_shape.result()));
+ TF_RETURN_IF_ERROR(
+ computation_layout->mutable_result_layout()->CopyLayoutFromShape(
+ *shape_with_output_layout));
+ }
+
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ if (flags->xla_hlo_profile) {
+ module_config->enable_hlo_profiling(true);
+ }
+
+ module_config->set_seed(seed);
+ module_config->set_replica_count(execute_backend_->Replicas().size());
+
+ return std::move(module_config);
+}
+
+StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
+ std::vector<VersionedComputationHandle> versioned_handles,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ Backend* backend,
+ std::vector<perftools::gputools::StreamExecutor*> executors) {
+ // Dump computation proto state if flag is set.
+ std::vector<std::unique_ptr<SessionModule>> session_modules;
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ const string& directory_path = flags->xla_dump_computations_to;
+ const string& other_directory_path = flags->xla_dump_executions_to;
+ if ((!directory_path.empty() || !other_directory_path.empty())) {
+ for (int64 i = 0; i < versioned_handles.size(); ++i) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<SessionModule> session_module,
+ computation_tracker_.SnapshotComputation(
+ versioned_handles[i].handle));
+ if (!directory_path.empty()) {
+ string filename =
+ tensorflow::strings::Printf("computation_%lld__%s__version_%lld",
+ versioned_handles[i].handle.handle(),
+ session_module->entry().name().c_str(),
+ versioned_handles[i].version);
+ TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
+ *session_module));
+ session_modules.push_back(std::move(session_module));
+ }
+ }
+ }
+
+ VLOG(1) << "building executables from:";
+ for (const VersionedComputationHandle& versioned_handle : versioned_handles) {
+ VLOG(1) << versioned_handle.handle.handle() << "@v"
+ << versioned_handle.version;
+ }
+
+ std::vector<std::unique_ptr<HloModule>> modules;
+ for (const VersionedComputationHandle& versioned_handle : versioned_handles) {
+ TF_ASSIGN_OR_RETURN(auto module,
+ computation_tracker_.BuildHloModule(
+ versioned_handle,
+ /*include_unused_parameters=*/true));
+ modules.push_back(std::move(module));
+ }
+
+ Compiler::HloDumper hlo_dumper = MakeHloDumper();
+ TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<Executable>> executables,
+ backend->compiler()->Compile(
+ std::move(modules), std::move(module_configs),
+ hlo_dumper, std::move(executors)));
+
+ if (!other_directory_path.empty()) {
+ for (int64 i = 0; i < versioned_handles.size(); ++i) {
+ executables[i]->set_session_module(std::move(session_modules[i]));
+ }
+ }
+
+ return std::move(executables);
+}
+
+StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
+ const VersionedComputationHandle& versioned_handle,
+ std::unique_ptr<HloModuleConfig> module_config,
+ bool executable_for_compute_constant,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, se::StreamExecutor* executor) {
+ // Dump computation proto state if flag is set.
+ std::unique_ptr<SessionModule> session_module;
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ const string& directory_path = flags->xla_dump_computations_to;
+ const string& other_directory_path = flags->xla_dump_executions_to;
+ if (!executable_for_compute_constant &&
+ (!directory_path.empty() || !other_directory_path.empty())) {
+ TF_ASSIGN_OR_RETURN(
+ session_module,
+ computation_tracker_.SnapshotComputation(versioned_handle.handle));
+ if (!directory_path.empty()) {
+ string filename = tensorflow::strings::Printf(
+ "computation_%lld__%s__version_%lld",
+ versioned_handle.handle.handle(),
+ session_module->entry().name().c_str(), versioned_handle.version);
+ TF_RETURN_IF_ERROR(Executable::DumpToDirectory(directory_path, filename,
+ *session_module));
+ }
+ }
+
+ VLOG(1) << tensorflow::strings::Printf("building executable %lld@v%lld",
+ versioned_handle.handle.handle(),
+ versioned_handle.version);
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModule> module,
+ computation_tracker_.BuildHloModule(
+ versioned_handle,
+ /*include_unused_parameters=*/!executable_for_compute_constant));
+
+ Compiler::HloDumper hlo_dumper = MakeHloDumper();
+ if (executable_for_compute_constant &&
+ !flags->xla_hlo_graph_for_compute_constant) {
+ hlo_dumper = [](const HloModule&, const string&) {};
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ backend->compiler()->Compile(std::move(module), std::move(module_config),
+ hlo_dumper, executor));
+
+ if (!other_directory_path.empty()) {
+ executable->set_session_module(std::move(session_module));
+ }
+
+ return std::move(executable);
+}
+
+StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
+ const VersionedComputationHandle& versioned_handle,
+ std::unique_ptr<HloModuleConfig> module_config,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, perftools::gputools::StreamExecutor* executor,
+ ExecutionProfile* profile) {
+ std::shared_ptr<Executable> executable =
+ compilation_cache_.LookUp(versioned_handle, *module_config);
+
+ if (executable != nullptr) {
+ // Executable found in the computation cache.
+ if (profile != nullptr) {
+ profile->set_compilation_cache_hit(true);
+ }
+ return executable;
+ }
+
+ uint64 start_micros =
+ // Avoid reading the clock if we don't want timing info
+ (profile != nullptr) ? tensorflow::Env::Default()->NowMicros() : 0;
+
+ // Take a copy of the module config, as compilation introduces layouts where
+ // layouts were optional before.
+ HloModuleConfig original_module_config = *module_config;
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable_unique_ptr,
+ BuildExecutable(versioned_handle, std::move(module_config),
+ /*executable_for_compute_constant=*/false, arguments,
+ execute_backend_.get(), executor));
+
+ if (profile != nullptr) {
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+ uint64 milliseconds = (end_micros - start_micros) / 1000;
+ profile->set_compilation_cache_hit(false);
+ profile->set_compile_time_ms(milliseconds);
+ }
+
+ // Insert executable into the cache.
+ return compilation_cache_.Insert(std::move(executable_unique_ptr),
+ original_module_config);
+}
+
+StatusOr<std::vector<GlobalDataHandle>>
+Service::ExecuteParallelAndRegisterResult(
+ tensorflow::gtl::ArraySlice<Executable*> executables,
+ tensorflow::gtl::ArraySlice<
+ std::vector<perftools::gputools::DeviceMemoryBase>>
+ arguments,
+ Backend* backend,
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*> executors,
+ tensorflow::gtl::ArraySlice<string> result_tags) {
+ // TODO(b/33943292): Support for replication when using multiple computations.
+ TF_RET_CHECK(backend->Replicas().size() == 1);
+
+ // Set up streams.
+ std::vector<std::unique_ptr<se::Stream>> streams;
+
+ auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
+ for (std::unique_ptr<se::Stream>& stream : streams) {
+ backend->ReleaseStream(std::move(stream));
+ }
+ });
+
+ for (se::StreamExecutor* executor : executors) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
+ backend->AcquireStream(executor));
+    // Push back only after acquiring so the releaser only sees real streams.
+ streams.push_back(std::move(stream));
+ }
+
+ // Set up run options.
+ std::vector<ExecutableRunOptions> run_options;
+ for (const std::unique_ptr<se::Stream>& stream : streams) {
+ run_options.emplace_back();
+ auto& options = run_options.back();
+ options.set_stream(stream.get());
+ options.set_allocator(backend->memory_allocator());
+ options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ backend->eigen_intra_op_thread_pool_device());
+ }
+
+ // Asynchronously launch all executables.
+ std::vector<GlobalDataHandle> result_handles;
+ for (int64 i = 0; i < executables.size(); i++) {
+ TF_ASSIGN_OR_RETURN(
+ perftools::gputools::DeviceMemoryBase result,
+ executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i]));
+ result_handles.push_back(allocation_tracker_.Register(
+ backend, executors[i]->device_ordinal(), result,
+ executables[i]->result_shape(), result_tags[i]));
+ }
+
+ // Wait for all executions to complete.
+ for (int64 i = 0; i < result_handles.size(); ++i) {
+ if (!streams[i]->BlockHostUntilDone()) {
+ return InternalError("failed to complete execution for stream %lld", i);
+ }
+ }
+
+ return result_handles;
+}
+
+StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
+ Executable* executable,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, perftools::gputools::StreamExecutor* executor,
+ const string& result_tag, ExecutionProfile* profile) {
+ TF_RET_CHECK(!backend->Replicas().empty());
+
+ // Set up streams.
+ std::vector<std::unique_ptr<se::Stream>> streams;
+
+ auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
+ for (std::unique_ptr<se::Stream>& stream : streams) {
+ backend->ReleaseStream(std::move(stream));
+ }
+ });
+
+ for (se::StreamExecutor* executor : backend->Replicas()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
+ backend->AcquireStream(executor));
+    // Push back only after acquiring so the releaser only sees real streams.
+ streams.push_back(std::move(stream));
+ }
+
+ // Set up run options.
+ std::vector<ExecutableRunOptions> run_options;
+ for (const std::unique_ptr<se::Stream>& stream : streams) {
+ run_options.emplace_back();
+ auto& options = run_options.back();
+ options.set_stream(stream.get());
+ options.set_allocator(backend->memory_allocator());
+ options.set_inter_op_thread_pool(backend->inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ backend->eigen_intra_op_thread_pool_device());
+ }
+
+ perftools::gputools::DeviceMemoryBase result;
+ if (backend->Replicas().size() == 1) {
+ TF_ASSIGN_OR_RETURN(
+ result, ExecuteOnStreamWrapper<StatusOr<se::DeviceMemoryBase>>(
+ executable, &run_options[0], profile,
+ [&arguments](Executable* executable,
+ const ExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile) {
+ return executable->ExecuteOnStream(run_options, arguments,
+ hlo_execution_profile);
+ }));
+ } else {
+ std::vector<
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>>
+ repeated_arguments(backend->Replicas().size(), arguments);
+
+ TF_ASSIGN_OR_RETURN(
+ auto results,
+ executable->ExecuteOnStreams(run_options, repeated_arguments));
+ TF_RET_CHECK(!results.empty());
+ result = results[0];
+ }
+ return allocation_tracker_.Register(backend, executor->device_ordinal(),
+ result, executable->result_shape(),
+ result_tag);
+}
+
+tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) {
+ TF_ASSIGN_OR_RETURN(UserComputation * computation,
+ computation_tracker_.Resolve(arg->computation()));
+ return computation->SetReturnValue(arg->operand());
+}
+
+tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) {
+ VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString();
+
+ std::vector<std::vector<se::DeviceMemoryBase>> all_arguments;
+ std::vector<perftools::gputools::StreamExecutor*> executors;
+ std::vector<VersionedComputationHandle> versioned_handles;
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
+ std::vector<string> computation_names;
+
+ if (arg->requests_size() > execute_backend_->stream_executors().size()) {
+ return FailedPrecondition(
+ "there are not enough stream executors to execute %d computations",
+ arg->requests_size());
+ }
+
+ for (int64 i = 0; i < arg->requests_size(); ++i) {
+ // Get the stream executor on which the computation will run. Select the
+ // specific device if requested, otherwise select the i'th device from the
+ // list of available stream executors.
+ se::StreamExecutor* executor;
+ if (arg->requests(i).has_device_handle()) {
+ executor =
+ execute_backend_
+ ->stream_executors()[arg->requests(i).device_handle().handle()];
+ } else {
+ executor = execute_backend_->stream_executors()[i];
+ }
+ CHECK(executor != nullptr);
+
+ // Resolve the UserComputation object associated with the requested
+ // computation and compute the program shape.
+ const ExecuteRequest& request = arg->requests(i);
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(request.computation()));
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+ if (user_computation->request_count(versioned_handle.version) == 0) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+ const Shape* shape_with_output_layout =
+ request.has_shape_with_output_layout()
+ ? &request.shape_with_output_layout()
+ : nullptr;
+
+ // Resolve the allocations for the arguments of the computation, and create
+ // a vector of device memory offsets for the arguments from the allocations.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const Allocation*> arg_allocations,
+ ResolveAndValidateArguments(request.arguments(), execute_backend_.get(),
+ executor->device_ordinal()));
+ std::vector<se::DeviceMemoryBase> arguments;
+ for (const Allocation* allocation : arg_allocations) {
+ arguments.push_back(allocation->device_memory());
+ }
+
+ // Create an HloModuleConfig object for the computation, given the shape of
+ // the program and the argument allocations.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arg_allocations,
+ shape_with_output_layout, request.seed()));
+ VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: "
+ << module_config->entry_computation_layout().ToString();
+
+    // Add to the vectors that are used to build and execute the computations
+    // after the loop.
+ all_arguments.push_back(arguments);
+ versioned_handles.push_back(versioned_handle);
+ module_configs.push_back(std::move(module_config));
+ computation_names.push_back(user_computation->name());
+ executors.push_back(executor);
+ }
+
+ // Build the user computations into HloModules and compile to generate the
+ // executables.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<std::unique_ptr<Executable>> executables,
+ BuildExecutables(versioned_handles, std::move(module_configs),
+ execute_backend_.get(), executors));
+ std::vector<Executable*> executable_ptrs;
+ for (const auto& executable : executables) {
+ executable_ptrs.push_back(executable.get());
+ }
+
+ // Execute the generated executables in parallel and return the device
+ // handles for each computation's output.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<GlobalDataHandle> outputs,
+ ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments,
+ execute_backend_.get(), executors,
+ computation_names));
+ for (const GlobalDataHandle& output : outputs) {
+ ExecuteResponse response;
+ *response.mutable_output() = output;
+ *result->add_responses() = response;
+ }
+
+ VLOG(1) << "successfully completed 'execute-parallel' request";
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) {
+ const int64 available_device_count =
+ execute_backend_->stream_executors().size();
+ const int64 replicas = execute_backend_->Replicas().size();
+ if (available_device_count < arg->device_count() * replicas) {
+ return ResourceExhausted(
+ "Requested device count (%lld) exceeds the number of available devices "
+ "on the target (%lld)",
+ arg->device_count(), available_device_count);
+ }
+
+ for (int64 i = 0; i < arg->device_count(); ++i) {
+ DeviceHandle device_handle;
+ device_handle.set_handle(
+ execute_backend_->stream_executors()[i * replicas]->device_ordinal());
+ *result->add_device_handles() = device_handle;
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::Execute(const ExecuteRequest* arg,
+ ExecuteResponse* result) {
+ VLOG(1) << "running execute request: " << arg->ShortDebugString();
+
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ if (user_computation->request_count(versioned_handle.version) == 0) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const Allocation*> arg_allocations,
+ ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
+ execute_backend_->default_device_ordinal()));
+
+ const Shape* shape_with_output_layout = arg->has_shape_with_output_layout()
+ ? &arg->shape_with_output_layout()
+ : nullptr;
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arg_allocations,
+ shape_with_output_layout, arg->seed()));
+
+ VLOG(3) << "Execute created HloModuleConfig computation layout: "
+ << module_config->entry_computation_layout().ToString();
+
+ std::vector<se::DeviceMemoryBase> arguments;
+ for (const Allocation* allocation : arg_allocations) {
+ arguments.push_back(allocation->device_memory());
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<Executable> executable,
+ BuildAndCacheExecutable(versioned_handle, std::move(module_config),
+ arguments, execute_backend_.get(),
+ execute_backend_->default_stream_executor(),
+ result->mutable_profile()));
+
+ if (executable->dumping()) {
+ executable->session_module()->set_execution_platform(
+ execute_backend_->platform()->Name());
+ TF_RETURN_IF_ERROR(
+ RecordArguments(arg_allocations, executable->session_module()));
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ *result->mutable_output(),
+ ExecuteAndRegisterResult(
+ executable.get(), arguments, execute_backend_.get(),
+ execute_backend_->default_stream_executor(),
+ "result of " + user_computation->name(), result->mutable_profile()));
+
+ if (executable->dumping()) {
+ TF_ASSIGN_OR_RETURN(const Allocation* result_allocation,
+ allocation_tracker_.Resolve(result->output()));
+ TF_RETURN_IF_ERROR(
+ RecordResult(result_allocation, executable->session_module()));
+ TF_RETURN_IF_ERROR(executable->DumpSessionModule());
+ }
+
+ VLOG(1) << "successfully completed 'execute' request";
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) {
+ VLOG(1) << "running execute-async request: " << arg->ShortDebugString();
+
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+ if (user_computation->request_count(versioned_handle.version) == 0) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> program_shape,
+ user_computation->ComputeProgramShape(versioned_handle.version));
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const Allocation*> arg_allocations,
+ ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(),
+ execute_backend_->default_device_ordinal()));
+
+ const Shape* shape_with_output_layout = arg->has_shape_with_output_layout()
+ ? &arg->shape_with_output_layout()
+ : nullptr;
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(*program_shape, arg_allocations,
+ shape_with_output_layout, arg->seed()));
+
+ VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: "
+ << module_config->entry_computation_layout().ToString();
+
+ std::vector<se::DeviceMemoryBase> arguments;
+ for (const Allocation* allocation : arg_allocations) {
+ arguments.push_back(allocation->device_memory());
+ }
+
+ ExecutionProfile profile;
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<Executable> executable,
+ BuildAndCacheExecutable(versioned_handle, std::move(module_config),
+ arguments, execute_backend_.get(),
+ execute_backend_->default_stream_executor(),
+ &profile));
+
+ TF_RET_CHECK(!execute_backend_->Replicas().empty());
+ // Set up streams.
+ std::vector<std::unique_ptr<se::Stream>> streams;
+
+ auto stream_releaser = ::tensorflow::gtl::MakeCleanup([this, &streams]() {
+ for (std::unique_ptr<se::Stream>& stream : streams) {
+ execute_backend_->ReleaseStream(std::move(stream));
+ }
+ });
+
+ for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
+ execute_backend_->AcquireStream(executor));
+    // Push back only after acquiring so the releaser only sees real streams.
+ streams.push_back(std::move(stream));
+ }
+
+ perftools::gputools::DeviceMemoryBase result_data;
+ for (const std::unique_ptr<se::Stream>& stream : streams) {
+ ExecutableRunOptions options;
+ options.set_stream(stream.get());
+ options.set_allocator(execute_backend_->memory_allocator());
+ options.set_inter_op_thread_pool(execute_backend_->inter_op_thread_pool());
+ options.set_intra_op_thread_pool(
+ execute_backend_->eigen_intra_op_thread_pool_device());
+
+ TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase this_result_data,
+ executable->ExecuteAsyncOnStream(&options, arguments));
+
+ // Take the first result.
+ if (result_data == nullptr) {
+ result_data = this_result_data;
+ }
+ }
+
+ auto output = allocation_tracker_.Register(
+ execute_backend_.get(), execute_backend_->default_device_ordinal(),
+ result_data, executable->result_shape(),
+ "result of " + user_computation->name());
+
+ *result->mutable_execution() = execution_tracker_.Register(
+ execute_backend_.get(), std::move(streams), profile, output);
+ streams.clear();
+
+ VLOG(1) << "successfully completed 'execute-async' request";
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) {
+ TF_ASSIGN_OR_RETURN(const auto execution,
+ execution_tracker_.Resolve(arg->execution()));
+
+ TF_RETURN_IF_ERROR(execution->BlockUntilDone());
+
+ *result->mutable_output() = execution->result();
+ *result->mutable_profile() = execution->profile();
+
+ TF_RETURN_IF_ERROR(execution_tracker_.Unregister(arg->execution()));
+ VLOG(1) << "successfully completed 'wait-for-execution' request";
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg,
+ TransferToClientResponse* result) {
+ TF_ASSIGN_OR_RETURN(const Allocation* allocation,
+ allocation_tracker_.Resolve(arg->data()));
+
+ const Shape* literal_shape;
+ if (arg->has_shape_with_layout()) {
+ if (!LayoutUtil::HasLayout(arg->shape_with_layout())) {
+ return InvalidArgument("shape_with_layout must have layout if present.");
+ }
+ literal_shape = &arg->shape_with_layout();
+ } else {
+ literal_shape = &allocation->shape();
+ }
+
+ return LiteralFromAllocation(allocation, *literal_shape,
+ result->mutable_literal());
+}
+
+tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
+ TransferToServerResponse* result) {
+ const Literal& literal = arg->literal();
+ const Shape& shape = literal.shape();
+
+ if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
+ // TODO(b/32990684): Tuple transfers to host end up allocating further
+ // buffers - implement that correctly.
+ return Unimplemented(
+ "Tuple transfers to the device not supported with replication.");
+ }
+
+ se::StreamExecutor* stream_executor;
+ if (arg->has_device_handle()) {
+ TF_ASSIGN_OR_RETURN(
+ stream_executor,
+ execute_backend_->stream_executor(arg->device_handle().handle()));
+ } else {
+ stream_executor = execute_backend_->default_stream_executor();
+ }
+
+ // Allocate memory on the device, using the stream executor. The size of the
+ // allocation is obtained by examining the shape of the literal passed from
+ // the client. An allocation handle is returned in the response.
+ int64 allocation_size =
+ execute_backend_->transfer_manager()->GetByteSizeRequirement(shape);
+
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
+ execute_backend_->memory_allocator()->Allocate(
+ stream_executor->device_ordinal(), allocation_size));
+
+ *result->mutable_data() = allocation_tracker_.Register(
+ execute_backend_.get(), stream_executor->device_ordinal(), allocation,
+ shape, tensorflow::strings::StrCat("TransferToServer literal of size ",
+ allocation_size));
+
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ execute_backend_->Replicas(stream_executor->device_ordinal()));
+ for (se::StreamExecutor* executor : replicas) {
+ TF_RETURN_IF_ERROR(
+ execute_backend_->transfer_manager()->TransferLiteralToDevice(
+ executor, literal, &allocation));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) {
+ const int64 replica_count = execute_backend_->Replicas().size();
+ if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) {
+ return FailedPrecondition(
+ "%s",
+ tensorflow::strings::StrCat(
+ "The replica_id=", arg->replica_id(),
+ " on TransferToInfeedRequest not in range [0, replica_count=",
+ replica_count, ").")
+ .c_str());
+ }
+
+ se::StreamExecutor* executor;
+ if (arg->has_device_handle()) {
+ TF_ASSIGN_OR_RETURN(
+ auto replicas,
+ execute_backend_->Replicas(arg->device_handle().handle()));
+ executor = replicas[arg->replica_id()];
+ } else {
+ executor = execute_backend_->Replicas()[arg->replica_id()];
+ }
+
+ return execute_backend_->transfer_manager()->TransferLiteralToInfeed(
+ executor, arg->literal());
+}
+
+tensorflow::Status Service::ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) {
+ int first_device_ordinal = arg->has_device_handle()
+ ? arg->device_handle().handle()
+ : execute_backend_->default_device_ordinal();
+ TF_ASSIGN_OR_RETURN(auto executors,
+ execute_backend_->Replicas(first_device_ordinal));
+ for (se::StreamExecutor* executor : executors) {
+ TF_RETURN_IF_ERROR(
+ execute_backend_->transfer_manager()->ResetDevice(executor));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::TransferToClientInProcess(
+ const TransferToClientInProcessRequest* arg,
+ TransferToClientInProcessResponse* result) {
+ TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToClientInProcess"));
+
+ TF_ASSIGN_OR_RETURN(const Allocation* allocation,
+ allocation_tracker_.Resolve(arg->data()));
+
+ void* buffer = reinterpret_cast<void*>(arg->buffer());
+ int64 size = ShapeUtil::ByteSizeOf(allocation->shape());
+ TF_ASSIGN_OR_RETURN(
+ se::StreamExecutor * executor,
+ allocation->backend()->stream_executor(allocation->device_ordinal()));
+
+ return allocation->backend()->transfer_manager()->TransferBufferFromDevice(
+ executor, allocation->device_memory(), size, buffer);
+}
+
+tensorflow::Status Service::TransferToServerInProcess(
+ const TransferToServerInProcessRequest* arg,
+ TransferToServerInProcessResponse* result) {
+ TF_RETURN_IF_ERROR(CheckRunsInClientProcess("TransferToServerInProcess"));
+
+ const Shape& shape = arg->shape();
+
+ if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) {
+ // TODO(b/32990684): Tuple transfers to host end up allocating further
+ // buffers - implement that correctly.
+ return Unimplemented(
+ "Tuple transfers to the device not supported with replication.");
+ }
+
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("shape must have layout");
+ }
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+
+ const void* buffer = reinterpret_cast<const void*>(arg->buffer());
+
+ // Allocate memory on the device, using the stream executor. The size of the
+ // allocation is obtained by examining the shape of the literal passed from
+ // the client. An allocation handle is returned in the response.
+ int64 allocation_size =
+ execute_backend_->transfer_manager()->GetByteSizeRequirement(shape);
+ se::StreamExecutor* stream_executor =
+ execute_backend_->default_stream_executor();
+
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation,
+ execute_backend_->memory_allocator()->Allocate(
+ stream_executor->device_ordinal(), allocation_size));
+
+ *result->mutable_data() = allocation_tracker_.Register(
+ execute_backend_.get(), stream_executor->device_ordinal(), allocation,
+ shape, tensorflow::strings::StrCat("TransferToServer literal of size ",
+ allocation_size));
+
+ for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
+ TF_RETURN_IF_ERROR(
+ execute_backend_->transfer_manager()->TransferBufferToDevice(
+ executor, allocation_size, buffer, &allocation));
+ }
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandleAtOperation(arg->operand());
+
+ if (user_computation->request_count(versioned_handle.version) == 0) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(bool is_constant,
+ user_computation->IsConstant(arg->operand()));
+
+ result->set_is_constant(is_constant);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandleAtOperation(arg->operand());
+
+ if (user_computation->request_count(versioned_handle.version) == 0) {
+ return InvalidArgument("computations may not be empty");
+ }
+
+ TF_ASSIGN_OR_RETURN(bool is_constant,
+ user_computation->IsConstant(arg->operand()));
+
+ if (!is_constant) {
+ return InvalidArgument("Operand to ComputeConstant depends on parameter.");
+ }
+
+ // We can't use ComputeProgramShape because it checks that all parameter
+ // instructions are present and contiguous. Instead construct ProgramShape
+ // directly.
+ ProgramShape program_shape;
+ TF_ASSIGN_OR_RETURN(*program_shape.mutable_result(),
+ user_computation->GetShape(arg->operand()));
+
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
+
+ Shape shape_with_output_layout(program_shape.result());
+ if (arg->has_output_layout()) {
+ TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
+ arg->output_layout(), shape_with_output_layout));
+ *shape_with_output_layout.mutable_layout() = arg->output_layout();
+ }
+
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<HloModuleConfig> module_config,
+ CreateModuleConfig(
+ program_shape, {},
+ arg->has_output_layout() ? &shape_with_output_layout : nullptr,
+ /*seed=*/0));
+
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<Executable> executable,
+ BuildExecutable(versioned_handle, std::move(module_config),
+ /*executable_for_compute_constant=*/true,
+ /*arguments=*/{}, compute_constant_backend_.get(),
+ compute_constant_backend_->default_stream_executor()));
+
+ TF_ASSIGN_OR_RETURN(
+ *result->mutable_output(),
+ ExecuteAndRegisterResult(
+ executable.get(), /*arguments=*/{}, compute_constant_backend_.get(),
+ compute_constant_backend_->default_stream_executor(),
+ "constant computed from " + user_computation->name(),
+ /*profile=*/nullptr));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) {
+ TF_ASSIGN_OR_RETURN(const Allocation* allocation,
+ allocation_tracker_.Resolve(arg->data()));
+ *result->mutable_shape() = allocation->shape();
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::GetComputationShape(
+ const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ computation->GetVersionedHandle();
+
+ TF_ASSIGN_OR_RETURN(
+ auto program_shape,
+ computation->ComputeProgramShape(versioned_handle.version));
+ *result->mutable_program_shape() = *program_shape;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ TF_ASSIGN_OR_RETURN(*result->mutable_shape(),
+ computation->GetShape(arg->operand()));
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::GetComputationStats(
+ const ComputationStatsRequest* arg, ComputationStatsResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ VersionedComputationHandle versioned_handle =
+ user_computation->GetVersionedHandle();
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ computation_tracker_.BuildHloModule(versioned_handle));
+
+ MakeHloDumper()(*module, "computation statistics subject");
+
+ // Run HLO analysis to get the computation statistics.
+ HloCostAnalysis analysis;
+
+ TF_RETURN_IF_ERROR(
+ module->entry_computation()->root_instruction()->Accept(&analysis));
+
+ ComputationStats stats;
+ stats.set_flop_count(analysis.flop_count());
+ stats.set_transcendental_count(analysis.transcendental_count());
+ *result->mutable_stats() = stats;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::CheckRunsInClientProcess(
+ const string& method_name) const {
+ if (runs_in_client_process_) {
+ return tensorflow::Status::OK();
+ } else {
+ return FailedPrecondition(
+ "%s only supported if service runs in the same process as the client",
+ method_name.c_str());
+ }
+}
+
+template <typename RequestT, typename ResponseT>
+tensorflow::Status Service::AddInstruction(
+ const RequestT* arg, ResponseT* result,
+ const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
+ adder) {
+ TF_ASSIGN_OR_RETURN(UserComputation * computation,
+ computation_tracker_.Resolve(arg->computation()));
+
+ TF_ASSIGN_OR_RETURN(*result->mutable_output(), adder(computation));
+ return tensorflow::Status::OK();
+}
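+
+// Illustrative use of AddInstruction (not from the original source), assuming
+// a request type that exposes computation() and constant_request():
+//
+//   return AddInstruction(arg, result, [arg](UserComputation* computation) {
+//     return computation->AddConstantInstruction(arg->constant_request());
+//   });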
+
+tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
+ TF_ASSIGN_OR_RETURN(UserComputation * computation,
+ computation_tracker_.Resolve(arg->computation()));
+ StatusOr<ComputationDataHandle> handle;
+
+ switch (arg->op_case()) {
+ case OpRequest::kBinaryOpRequest:
+ handle = computation->AddBinaryInstruction(arg->binary_op_request());
+ break;
+ case OpRequest::kBroadcastRequest:
+ handle = computation->AddBroadcastInstruction(arg->broadcast_request());
+ break;
+ case OpRequest::kCallRequest: {
+ TF_ASSIGN_OR_RETURN(
+ UserComputation * to_apply,
+ computation_tracker_.Resolve(arg->call_request().to_apply()));
+ handle = computation->AddCallInstruction(arg->call_request(), *to_apply);
+ break;
+ }
+ case OpRequest::kConcatenateRequest:
+ handle =
+ computation->AddConcatenateInstruction(arg->concatenate_request());
+ break;
+ case OpRequest::kConstantRequest:
+ handle = computation->AddConstantInstruction(arg->constant_request());
+ break;
+ case OpRequest::kConvertRequest:
+ handle = computation->AddConvertInstruction(arg->convert_request());
+ break;
+ case OpRequest::kConvolveRequest:
+ handle = computation->AddConvolveInstruction(arg->convolve_request());
+ break;
+ case OpRequest::kCrossReplicaSumRequest:
+ handle = computation->AddCrossReplicaSumInstruction(
+ arg->cross_replica_sum_request());
+ break;
+ case OpRequest::kCustomCallRequest:
+ handle =
+ computation->AddCustomCallInstruction(arg->custom_call_request());
+ break;
+ case OpRequest::kDynamicSliceRequest:
+ handle =
+ computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
+ break;
+ case OpRequest::kDynamicUpdateSliceRequest:
+ handle = computation->AddDynamicUpdateSliceInstruction(
+ arg->dynamic_update_slice_request());
+ break;
+ case OpRequest::kGetTupleElementRequest:
+ handle = computation->AddGetTupleElementInstruction(
+ arg->get_tuple_element_request());
+ break;
+ case OpRequest::kInfeedRequest:
+ handle = computation->AddInfeedInstruction(arg->infeed_request());
+ break;
+ case OpRequest::kMapRequest: {
+ TF_ASSIGN_OR_RETURN(
+ UserComputation * to_apply,
+ computation_tracker_.Resolve(arg->map_request().to_apply()));
+ handle = computation->AddMapInstruction(arg->map_request(), *to_apply);
+ break;
+ }
+ case OpRequest::kPadRequest:
+ handle = computation->AddPadInstruction(arg->pad_request());
+ break;
+ case OpRequest::kParameterRequest:
+ handle = computation->AddParameterInstruction(arg->parameter_request());
+ break;
+ case OpRequest::kReduceRequest: {
+ TF_ASSIGN_OR_RETURN(
+ UserComputation * to_apply,
+ computation_tracker_.Resolve(arg->reduce_request().to_apply()));
+ handle =
+ computation->AddReduceInstruction(arg->reduce_request(), *to_apply);
+ break;
+ }
+ case OpRequest::kReduceWindowRequest: {
+ TF_ASSIGN_OR_RETURN(UserComputation * to_apply,
+ computation_tracker_.Resolve(
+ arg->reduce_window_request().to_apply()));
+ handle = computation->AddReduceWindowInstruction(
+ arg->reduce_window_request(), *to_apply);
+ break;
+ }
+ case OpRequest::kReshapeRequest:
+ handle = computation->AddReshapeInstruction(arg->reshape_request());
+ break;
+ case OpRequest::kReverseRequest:
+ handle = computation->AddReverseInstruction(arg->reverse_request());
+ break;
+ case OpRequest::kRngRequest:
+ handle = computation->AddRngInstruction(arg->rng_request());
+ break;
+ case OpRequest::kSelectAndScatterRequest: {
+ TF_ASSIGN_OR_RETURN(UserComputation * select,
+ computation_tracker_.Resolve(
+ arg->select_and_scatter_request().select()));
+ TF_ASSIGN_OR_RETURN(UserComputation * scatter,
+ computation_tracker_.Resolve(
+ arg->select_and_scatter_request().scatter()));
+ handle = computation->AddSelectAndScatterInstruction(
+ arg->select_and_scatter_request(), *select, *scatter);
+ break;
+ }
+ case OpRequest::kSliceRequest:
+ handle = computation->AddSliceInstruction(arg->slice_request());
+ break;
+ case OpRequest::kTernaryOpRequest:
+ handle = computation->AddTernaryInstruction(arg->ternary_op_request());
+ break;
+ case OpRequest::kTraceRequest:
+ return computation->AddTraceInstruction(arg->trace_request());
+ case OpRequest::kUnaryOpRequest:
+ handle = computation->AddUnaryInstruction(arg->unary_op_request());
+ break;
+ case OpRequest::kVariadicOpRequest:
+ handle = computation->AddVariadicInstruction(arg->variadic_op_request());
+ break;
+ case OpRequest::kWhileRequest: {
+ TF_ASSIGN_OR_RETURN(
+ UserComputation * condition,
+ computation_tracker_.Resolve(arg->while_request().condition()));
+ TF_ASSIGN_OR_RETURN(
+ UserComputation * body,
+ computation_tracker_.Resolve(arg->while_request().body()));
+ handle = computation->AddWhileInstruction(arg->while_request(),
+ *condition, *body);
+ break;
+ }
+ case OpRequest::kSendRequest: {
+ TF_RETURN_IF_ERROR(
+ channel_tracker_.RegisterSend(arg->send_request().channel_handle()));
+ TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request()));
+ return tensorflow::Status::OK();
+ }
+ case OpRequest::kRecvRequest: {
+ TF_RETURN_IF_ERROR(
+ channel_tracker_.RegisterRecv(arg->recv_request().channel_handle()));
+ handle = computation->AddRecvInstruction(arg->recv_request());
+ break;
+ }
+ default:
+ return InvalidArgument("Unsupported operation");
+ }
+ TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::SnapshotComputation(
+ const SnapshotComputationRequest* arg,
+ SnapshotComputationResponse* result) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<SessionModule> module,
+ computation_tracker_.SnapshotComputation(arg->computation()));
+
+ result->set_allocated_module(module.release());
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status Service::LoadComputationSnapshot(
+ const LoadComputationSnapshotRequest* arg,
+ LoadComputationSnapshotResponse* result) {
+ TF_ASSIGN_OR_RETURN(*result->mutable_computation(),
+ computation_tracker_.LoadSessionModule(arg->module()));
+ return tensorflow::Status::OK();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h
new file mode 100644
index 0000000000..1141e99fe3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/service.h
@@ -0,0 +1,457 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
+#include "tensorflow/compiler/xla/service/allocation_tracker.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/channel_tracker.h"
+#include "tensorflow/compiler/xla/service/compilation_cache.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/computation_tracker.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/execution_tracker.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/user_computation.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/service_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// Options to configure the service when it is created.
+class ServiceOptions {
+ public:
+ // Set the platform backing the service, or nullptr for the default platform.
+ ServiceOptions& set_platform(perftools::gputools::Platform* platform);
+ perftools::gputools::Platform* platform() const;
+
+ // Set the number of replicas to use when compiling replicated
+ // programs. The default is -1 meaning that the value is read from
+ // the xla_replicas flag.
+ ServiceOptions& set_number_of_replicas(int number_of_replicas);
+ int number_of_replicas() const;
+
+ private:
+ perftools::gputools::Platform* platform_ = nullptr;
+ int number_of_replicas_ = -1;
+};
+
+// The XLA service object, which is the same across all
+// platforms. It maintains the service state of computations and allocations,
+// and delegates target-specific requests to the target-specific infrastructure
+// (target-specific compiler, StreamExecutor).
+class Service : public ServiceInterface {
+ public:
+ // Factory method for creating a new Service.
+ static StatusOr<std::unique_ptr<Service>> NewService(
+ perftools::gputools::Platform* platform = nullptr);
+ static StatusOr<std::unique_ptr<Service>> NewService(
+ const ServiceOptions& options);
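+ //
+ // Example (sketch; assumes the caller is inside a function that can return
+ // a Status, as required by TF_ASSIGN_OR_RETURN):
+ //
+ //   ServiceOptions options;
+ //   options.set_number_of_replicas(1);
+ //   TF_ASSIGN_OR_RETURN(std::unique_ptr<Service> service,
+ //                       Service::NewService(options));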
+
+ // Creates a new computation with the given name.
+ // A unique ComputationHandle is returned.
+ tensorflow::Status Computation(const ComputationRequest* arg,
+ ComputationResponse* result) override;
+
+ // Unregisters a previously-allocated global handle.
+ //
+ // If the handle given is not currently allocated, a NOT_FOUND status is
+ // returned.
+ tensorflow::Status Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) override;
+
+ // Deconstructs a tuple. Returns a newly created GlobalDataHandle for each
+ // element in the tuple.
+ tensorflow::Status DeconstructTuple(
+ const DeconstructTupleRequest* arg,
+ DeconstructTupleResponse* result) override;
+
+ // Modifies the provided computation so that subsequent executions
+ // will compute the provided ComputationDataHandle, rather than the
+ // last expression enqueued on that Computation.
+ tensorflow::Status SetReturnValue(const SetReturnValueRequest* arg,
+ SetReturnValueResponse* results) override;
+
+ // Executes a computation with the provided global data passed as
+ // immutable arguments. Returns global data output and execution timing.
+ tensorflow::Status Execute(const ExecuteRequest* arg,
+ ExecuteResponse* result) override;
+
+ // Executes one or more computations in parallel with the provided global data
+ // passed as immutable arguments. Returns global data output for each
+ // computation.
+ tensorflow::Status ExecuteParallel(const ExecuteParallelRequest* arg,
+ ExecuteParallelResponse* result) override;
+
+ // Requests one or more device handles from the target.
+ //
+ // When N device handles are requested and the number of replicas is R, at
+ // least N * R devices must be available. The devices are assigned based on
+ // the device ordinals such that the first R available devices are assigned to
+ // the first set of replicas, and the next R devices to the second set of
+ // replicas, etc. Each returned device handle represents the device with
+ // replica id 0.
+ tensorflow::Status GetDeviceHandles(
+ const GetDeviceHandlesRequest* arg,
+ GetDeviceHandlesResponse* result) override;
+
+ // Asynchronously executes a computation with provided arguments. Invokes
+ // the provided computation with the provided global data passed as
+ // immutable arguments. Returns a handle to the execution.
+ tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) override;
+
+ // Waits until the specified execution is complete and returns the result.
+ // Calling this API multiple times with the same execution handle returns an
+ // error, since the execution handle is destroyed after the first call.
+ tensorflow::Status WaitForExecution(
+ const WaitForExecutionRequest* arg,
+ WaitForExecutionResponse* result) override;
+
+ // Requests that global data be transferred to the client in literal form.
+ tensorflow::Status TransferToClient(
+ const TransferToClientRequest* arg,
+ TransferToClientResponse* result) override;
+
+ // Requests that global data be copied into a buffer supplied by the client.
+ tensorflow::Status TransferToClientInProcess(
+ const TransferToClientInProcessRequest* arg,
+ TransferToClientInProcessResponse* result) override;
+
+ // Transfers data from a literal provided by the client, into device memory.
+ tensorflow::Status TransferToServer(
+ const TransferToServerRequest* arg,
+ TransferToServerResponse* result) override;
+
+ // Transfers data from a literal provided by the client, into the Infeed
+ // buffer of the device.
+ tensorflow::Status TransferToInfeed(
+ const TransferToInfeedRequest* arg,
+ TransferToInfeedResponse* result) override;
+
+ // Resets the device, clearing all existing state on the device.
+ tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) override;
+
+ // Transfers data from a buffer provided by the client, into device memory.
+ tensorflow::Status TransferToServerInProcess(
+ const TransferToServerInProcessRequest* arg,
+ TransferToServerInProcessResponse* result) override;
+
+ // Tests if an expression is a compile-time constant.
+ tensorflow::Status IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) override;
+
+ // Computes the value of a constant expression.
+ tensorflow::Status ComputeConstant(const ComputeConstantRequest* arg,
+ ComputeConstantResponse* result) override;
+
+ // Returns the shape (with layout) of an array associated with a given data
+ // handle.
+ tensorflow::Status GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) override;
+
+ // Returns the program shape of the computation associated with the given
+ // handle.
+ tensorflow::Status GetComputationShape(
+ const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) override;
+
+ /////
+ // Computation-oriented methods.
+
+ // Enqueues an Op on the computation.
+ tensorflow::Status Op(const OpRequest* arg, OpResponse* result) override;
+
+ // Retrieves the inferred shape for a value within a computation.
+ tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) override;
+
+ // Retrieves the statistics of a computation.
+ tensorflow::Status GetComputationStats(
+ const ComputationStatsRequest* arg,
+ ComputationStatsResponse* result) override;
+
+ // Snapshots the current state of a computation handle into a serializable
+ // protocol buffer form, so it can be loaded via
+ // LoadComputationSnapshot.
+ tensorflow::Status SnapshotComputation(
+ const SnapshotComputationRequest* arg,
+ SnapshotComputationResponse* result) override;
+
+ // Loads a computation from a serialized protocol buffer created via
+ // SnapshotComputation.
+ tensorflow::Status LoadComputationSnapshot(
+ const LoadComputationSnapshotRequest* arg,
+ LoadComputationSnapshotResponse* result) override;
+
+ // Creates a unique channel handle that can be used for Send/Recv
+ // instructions.
+ tensorflow::Status CreateChannelHandle(
+ const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) override;
+
+ // Returns the ComputationTracker of the current service instance.
+ // Only used in unit tests to access user computations from the client.
+ const ComputationTracker& computation_tracker() {
+ return computation_tracker_;
+ }
+
+ // Returns the backend used to execute computations.
+ const Backend& backend() const { return *execute_backend_; }
+ Backend* mutable_backend() { return execute_backend_.get(); }
+
+ protected:
+ // The constructor is protected. Use the NewService factory to create new
+ // service objects.
+ Service(std::unique_ptr<Backend> backend,
+ std::unique_ptr<Backend> compute_constant_backend);
+
+ static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
+
+ // Resolves the given argument handles in the allocation tracker and returns
+ // the corresponding allocations. The function also verifies that each
+ // allocation matches the given backend and device ordinal.
+ StatusOr<std::vector<const Allocation*>> ResolveAndValidateArguments(
+ tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
+ const Backend* backend, int device_ordinal);
+
+ // Creates an HLO module config for the given program shape and arguments. If
+ // shape_with_output_layout is not null, then the computation output layout is
+ // set to the layout of the given shape.
+ StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
+ const ProgramShape& program_shape,
+ tensorflow::gtl::ArraySlice<const Allocation*> arguments,
+ const Shape* shape_with_output_layout, uint64 seed);
+
+ // Builds an Executable for the given parameters. If
+ // executable_for_compute_constant is true, then the executable is intended to
+ // be used for ComputeConstant, which means dead parameter instructions are
+ // not included in the executable. The parameter "profile" can optionally
+ // point to an ExecutionProfile object which will be filled in with profile
+ // data relevant to compilation.
+ StatusOr<std::unique_ptr<Executable>> BuildExecutable(
+ const VersionedComputationHandle& versioned_handle,
+ std::unique_ptr<HloModuleConfig> module_config,
+ bool executable_for_compute_constant,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, perftools::gputools::StreamExecutor* executor);
+
+ // Same as BuildExecutable() above, but builds a list of Executables for the
+ // given computations that may interact with each other.
+ StatusOr<std::vector<std::unique_ptr<Executable>>> BuildExecutables(
+ std::vector<VersionedComputationHandle> versioned_handles,
+ std::vector<std::unique_ptr<HloModuleConfig>> module_configs,
+ Backend* backend,
+ std::vector<perftools::gputools::StreamExecutor*> executors);
+
+ // Similar to BuildExecutable, but looks in the compilation cache for the
+ // executable first. If the executable is not in the cache, it is built and
+ // inserted into the cache.
+ StatusOr<std::shared_ptr<Executable>> BuildAndCacheExecutable(
+ const VersionedComputationHandle& versioned_handle,
+ std::unique_ptr<HloModuleConfig> module_config,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, perftools::gputools::StreamExecutor* executor,
+ ExecutionProfile* profile);
+
+ // Runs the given executable with the given arguments and registers the result
+ // in the allocation tracker. The handle of the result from the tracker is
+ // returned. If the parameter "profile" is not null, it points to an
+ // ExecutionProfile object which will be filled in with profile data.
+ StatusOr<GlobalDataHandle> ExecuteAndRegisterResult(
+ Executable* executable,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Backend* backend, perftools::gputools::StreamExecutor* executor,
+ const string& result_tag, ExecutionProfile* profile);
+
+ // Runs the given executables with the given arguments and registers the result
+ // from each executable in the allocation tracker. The handles of the result
+ // from the tracker are returned.
+ StatusOr<std::vector<GlobalDataHandle>> ExecuteParallelAndRegisterResult(
+ tensorflow::gtl::ArraySlice<Executable*> executables,
+ tensorflow::gtl::ArraySlice<
+ std::vector<perftools::gputools::DeviceMemoryBase>>
+ arguments,
+ Backend* backend,
+ tensorflow::gtl::ArraySlice<perftools::gputools::StreamExecutor*>
+ executors,
+ tensorflow::gtl::ArraySlice<string> result_tags);
+
+ // Dumps the executed HLO according to service-associated flags.
+ static void DumpExecutedHlo(const HloModule& module, const string& label,
+ const HloExecutionProfile* profile);
+
+ // Returns an HLO dumper for use in the compiler (it refers to flags
+ // associated with the service).
+ static Compiler::HloDumper MakeHloDumper();
+
+ // Convenience function for adding an instruction to a user computation.
+ template <typename RequestT, typename ResponseT>
+ tensorflow::Status AddInstruction(
+ const RequestT* arg, ResponseT* result,
+ const std::function<StatusOr<ComputationDataHandle>(UserComputation*)>&
+ adder);
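+ //
+ // Example (sketch): a typical op handler can be written as
+ //
+ //   return AddInstruction(arg, result, [arg](UserComputation* computation) {
+ //     return computation->AddFooInstruction(arg->foo_request());
+ //   });
+ //
+ // Here AddFooInstruction and foo_request() are placeholders for the per-op
+ // UserComputation method and request accessor; they are not names defined in
+ // this header.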
+
+ // If the service is running in the client process
+ // (runs_in_client_process_ is true), returns tensorflow::Status::OK.
+ // Otherwise returns an appropriate error status with the given method name.
+ // Used for "InProcess" methods.
+ tensorflow::Status CheckRunsInClientProcess(const string& method_name) const;
+
+ // Convenience function which checks whether the given shape_with_layout
+ // (presumably passed by the client to set the result layout) is valid for the
+ // given computation result shape.
+ tensorflow::Status ValidateResultShapeWithLayout(
+ const Shape& shape_with_layout, const Shape& result_shape) const;
+
+ // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a
+ // timer for the execution, sets up HLO profiling if enabled, and fills in the
+ // given ExecutionProfile if non-null. The given execute_func should be a
+ // function which calls the desired ExecuteOnStream overload with the supplied
+ // arguments. The ExecuteOnStream overloads return different types so this
+ // method is templated on return-type of the execute function.
+ template <typename ReturnT>
+ ReturnT ExecuteOnStreamWrapper(
+ Executable* executable, const ExecutableRunOptions* run_options,
+ ExecutionProfile* profile,
+ std::function<ReturnT(Executable* executable,
+ const ExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile)>
+ execute_func);
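+ //
+ // Example usage (sketch; the ExecuteOnStream overload and its return type
+ // shown here are assumptions, not spelled out in this header):
+ //
+ //   auto result = ExecuteOnStreamWrapper<
+ //       StatusOr<perftools::gputools::DeviceMemoryBase>>(
+ //       executable, &run_options, profile,
+ //       [&arguments](Executable* executable,
+ //                    const ExecutableRunOptions* run_options,
+ //                    HloExecutionProfile* hlo_execution_profile) {
+ //         return executable->ExecuteOnStream(run_options, arguments,
+ //                                            hlo_execution_profile);
+ //       });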
+
+ // Tracks computations built via the API.
+ ComputationTracker computation_tracker_;
+
+ // Tracks channels created via the API.
+ ChannelTracker channel_tracker_;
+
+ // Tracks allocations made via the API and computation execution.
+ AllocationTracker allocation_tracker_;
+
+ // Tracks asynchronously launched executions via the API.
+ ExecutionTracker execution_tracker_;
+
+ // Cache containing previously built Executables.
+ CompilationCache compilation_cache_;
+
+ // Backend to compile and execute computations on.
+ //
+ // TODO(b/28616830): Support multiple backends for execution.
+ std::unique_ptr<Backend> execute_backend_;
+
+ // Backend to use when executing ComputeConstant.
+ std::unique_ptr<Backend> compute_constant_backend_;
+
+ // Whether the service runs in the same process as the client.
+ bool runs_in_client_process_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Service);
+};
+
+template <typename ReturnT>
+ReturnT Service::ExecuteOnStreamWrapper(
+ Executable* executable, const ExecutableRunOptions* run_options,
+ ExecutionProfile* profile,
+ std::function<ReturnT(Executable* executable,
+ const ExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile)>
+ execute_func) {
+ perftools::gputools::Stream* stream = run_options->stream();
+ std::unique_ptr<perftools::gputools::Timer> timer;
+ if (profile != nullptr) {
+ timer.reset(new perftools::gputools::Timer(stream->parent()));
+ stream->InitTimer(timer.get()).ThenStartTimer(timer.get());
+ }
+
+ VLOG(1) << "enqueueing executable on stream...";
+ // If the profiling flag isn't enabled, we pass nullptr as the profile to
+ // indicate profiling is not requested.
+ HloExecutionProfile hlo_execution_profile;
+ legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags();
+ HloExecutionProfile* profile_ptr =
+ flags->xla_hlo_profile && executable->hlo_profiling_enabled()
+ ? &hlo_execution_profile
+ : nullptr;
+
+ auto return_value = execute_func(executable, run_options, profile_ptr);
+
+ if (profile != nullptr) {
+ VLOG(1) << "enqueueing 'stop timer' and blocking host until done...";
+ stream->ThenStopTimer(timer.get()).BlockHostUntilDone();
+ VLOG(1) << "done with block-host-until-done";
+
+ // Merge in run time profile information from the executable.
+ profile->MergeFrom(executable->execution_profile());
+
+ // Overall execution time (in nanoseconds) from the executor timer.
+ profile->set_compute_and_transfer_time_ns(timer->Nanoseconds());
+
+ // TODO(b/28123297): On GPU we end up including transfer time in
+ // the compute time this way. Instead, we should get the correct
+ // value by measuring it. Setting the field here at least lets
+ // benchmarks provide *some* value for GPU computations.
+ //
+ // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually
+ // the compute time without the transfer time, so this way we get the
+ // correct compute time. We should instead have the correct value for
+ // compute_and_transfer_time and set compute_time to the compute time.
+ if (profile->compute_time_ns() == 0) {
+ profile->set_compute_time_ns(profile->compute_and_transfer_time_ns());
+ }
+ }
+
+ if (profile_ptr != nullptr) {
+ HloCostAnalysis analysis;
+ tensorflow::Status analysis_status =
+ executable->module().entry_computation()->root_instruction()->Accept(
+ &analysis);
+ if (analysis_status.ok()) {
+ XLA_LOG_LINES(tensorflow::INFO,
+ profile_ptr->ToString(
+ stream->parent()->GetDeviceDescription(), analysis));
+ }
+ DumpExecutedHlo(executable->module(), "Service::Execute", profile_ptr);
+ }
+
+ return return_value;
+}
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SERVICE_H_
diff --git a/tensorflow/compiler/xla/service/session.proto b/tensorflow/compiler/xla/service/session.proto
new file mode 100644
index 0000000000..ead3c1eb10
--- /dev/null
+++ b/tensorflow/compiler/xla/service/session.proto
@@ -0,0 +1,91 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This proto file defines messages which store the state of XLA
+// computations within the XLA service. A computation is stored as a record
+// of the operation requests used to build it.
+syntax = "proto3";
+
+import "tensorflow/compiler/xla/xla_data.proto";
+
+package xla;
+
+// Describes a single operation request.
+message OperationRequest {
+ ComputationDataHandle output_handle = 1;
+ Shape output_shape = 2;
+
+ // For operations which call embedded computations, such as "Map", these are
+ // the version(s) at which the embedded computation should be called. The
+ // version value of a computation is the ComputationDataHandle of the root of
+ // the computation at that point in time.
+ //
+ // "Call", "Map", "Reduce", and "ReduceWindow" operations take a single
+ // embedded computation so this field will have a single value for those
+ // operations.
+ //
+ // "While" operation takes two; index 0 is the "condition" version and index 1
+ // is the "body" version.
+ repeated int64 embedded_computation_versions = 3;
+
+ // The actual request, which is itself a tagged union of all possible
+ // operation request types.
+ OpRequest request = 4;
+}
+
+// Describes a sequence of operation requests which define an XLA
+// computation.
+message SessionComputation {
+ string name = 1;
+
+ // The ComputationHandle used to refer to this computation in the XLA
+ // service.
+ ComputationHandle computation_handle = 2;
+
+ // Map from ComputationDataHandle value to operation request. The highest
+ // ComputationDataHandle value corresponds to the root of the computation.
+ map<int64, OperationRequest> requests = 3;
+
+ // The list of Trace requests in this SessionComputation.
+ repeated TraceRequest trace_requests = 4;
+
+ // The list of Send requests in this SessionComputation.
+ repeated SendRequest send_requests = 5;
+}
+
+// Describes a group of SessionComputations with an "entry point" computation
+// that may refer to the other non-entry (AKA embedded) computations.
+//
+// This message is used to serialize a computation that has been built via the
+// XLA service API, along with its dependencies, for purposes such as
+// analysis/replay/file-storage.
+message SessionModule {
+ // The entry computation, which was requested for serialization. This may have
+ // referred to embedded computations, which are reflected below.
+ SessionComputation entry = 1;
+
+ // Embedded computations that are transitively referred to by the entry
+ // computation.
+ repeated SessionComputation embedded_computations = 2;
+
+ // The arguments passed to the computation.
+ repeated Literal arguments = 3;
+
+ // The result of the computation.
+ Literal result = 4;
+
+ // The name of the platform used to run the computation.
+ string execution_platform = 5;
+}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
new file mode 100644
index 0000000000..11559ad757
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -0,0 +1,1380 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <numeric>
+#include <set>
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+
+namespace {
+
+// Returns true if no element is present in slice more than once.
+bool AllUnique(tensorflow::gtl::ArraySlice<int64> slice) {
+ return std::set<int64>(slice.begin(), slice.end()).size() == slice.size();
+}
+
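+// Returns an InvalidArgument error if the given shape is a tuple or an opaque
+// shape; op_type only labels the error message. Returns OK otherwise.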
+tensorflow::Status ExpectNotTupleOrOpaque(const Shape& shape,
+ tensorflow::StringPiece op_type) {
+ if (ShapeUtil::IsTuple(shape)) {
+ return InvalidArgument("Expected non-tuple argument for %s. Got: %s",
+ op_type.ToString().c_str(),
+ ShapeUtil::HumanString(shape).c_str());
+ } else if (ShapeUtil::IsOpaque(shape)) {
+ return InvalidArgument("Expected non-opaque argument for %s. Got: %s",
+ op_type.ToString().c_str(),
+ ShapeUtil::HumanString(shape).c_str());
+ } else {
+ return tensorflow::Status::OK();
+ }
+}
+
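+// Checks that the reduction function has the shape of a scalar binary
+// reducer, (T, T) -> T, where T is compatible with both the init value shape
+// and the input element type (layouts are ignored). For example (sketch), an
+// F32 addition reducer has shape (f32[], f32[]) -> f32[].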
+tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
+ const Shape& init_value_shape,
+ const PrimitiveType& input_element_type) {
+ if (reducer_shape.parameters_size() != 2) {
+ return InvalidArgument(
+ "Reduction function must take 2 parameters, but "
+ "takes %d parameter(s).",
+ reducer_shape.parameters_size());
+ }
+
+ const Shape& accumulator_shape = reducer_shape.result();
+ if (ShapeUtil::Rank(accumulator_shape) != 0) {
+ return Unimplemented(
+ "Reduction function currently must have rank-0 result.");
+ }
+
+ // Check that the accumulator can be passed in as the first argument.
+ // Note: comparing here and below with Compatible since we don't care about
+ // layout in scalars - see b/26668201 for a longer-term vision.
+ if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(0))) {
+ return InvalidArgument(
+ "Reduction function's first parameter shape differs from the "
+ "result shape: %s vs %s",
+ ShapeUtil::HumanString(reducer_shape.parameters(0)).c_str(),
+ ShapeUtil::HumanString(accumulator_shape).c_str());
+ }
+
+ // Check that init_value's shape is suitable for reducer_shape.
+ if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
+ return InvalidArgument(
+ "Reduction function's accumulator shape differs from the "
+ "init_value shape: %s vs %s",
+ ShapeUtil::HumanString(accumulator_shape).c_str(),
+ ShapeUtil::HumanString(init_value_shape).c_str());
+ }
+
+ // Check that the inputs can be passed in as the second argument.
+ const Shape& input_element_shape =
+ ShapeUtil::MakeShape(input_element_type, {});
+ if (!ShapeUtil::Compatible(input_element_shape,
+ reducer_shape.parameters(1))) {
+ return InvalidArgument(
+ "Reduction function's second parameter shape differs from the "
+ "input type element type: %s vs %s",
+ ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
+ ShapeUtil::HumanString(input_element_shape).c_str());
+ }
+
+ // Currently the accumulator and inputs must be the same type,
+ // though that restriction could be relaxed.
+ if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
+ return InvalidArgument(
+ "Reduction function's second parameter shape currently must "
+ "match the result shape. Got %s vs %s",
+ ShapeUtil::HumanString(reducer_shape.parameters(1)).c_str(),
+ ShapeUtil::HumanString(accumulator_shape).c_str());
+ }
+
+ return tensorflow::Status::OK();
+}
+
+StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
+ const Window& window,
+ PrimitiveType element_type,
+ bool allow_negative_padding) {
+ if (window.dimensions_size() != ShapeUtil::Rank(base_shape)) {
+ return InvalidArgument(
+ "Window has dimension %d but base shape has dimension %lld.",
+ window.dimensions_size(), ShapeUtil::Rank(base_shape));
+ }
+
+ std::vector<int64> output_dimensions(window.dimensions_size());
+ for (int64 i = 0; i < window.dimensions_size(); ++i) {
+ const auto& dim = window.dimensions(i);
+ if (dim.size() <= 0) {
+ return InvalidArgument("Window has a non-positive dimension. Window: %s",
+ window.DebugString().c_str());
+ }
+ if (dim.stride() <= 0) {
+ return InvalidArgument("Window has a non-positive stride. Window: %s",
+ window.DebugString().c_str());
+ }
+ if (!allow_negative_padding && dim.padding_low() < 0) {
+ return InvalidArgument("Window has a negative low padding. Window: %s",
+ window.DebugString().c_str());
+ }
+ if (!allow_negative_padding && dim.padding_high() < 0) {
+ return InvalidArgument("Window has a negative high padding. Window: %s",
+ window.DebugString().c_str());
+ }
+ if (dim.base_dilation() < 1) {
+ return InvalidArgument(
+ "Window has a non-positive base area dilation factor. Window: %s",
+ window.DebugString().c_str());
+ }
+ if (dim.window_dilation() < 1) {
+ return InvalidArgument(
+ "Window has a non-positive window dilation factor. Window: %s",
+ window.DebugString().c_str());
+ }
+
+ const int64 dilated_base = window_util::DilatedBound(
+ ShapeUtil::GetDimension(base_shape, i), dim.base_dilation());
+ const int64 padded_dilated_base =
+ dim.padding_low() + dilated_base + dim.padding_high();
+ const int64 dilated_window =
+ window_util::DilatedBound(dim.size(), dim.window_dilation());
+
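+ // Worked example (sketch): base size 5, no dilation, padding_low = 1,
+ // padding_high = 1, window size 3, stride 2 gives dilated_base = 5,
+ // padded_dilated_base = 7, dilated_window = 3, and an output size of
+ // StridedBound(7, 3, 2) = 3, matching (5 + 1 + 1 - 3) / 2 + 1.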
+ output_dimensions[i] = window_util::StridedBound(
+ padded_dilated_base, dilated_window, dim.stride());
+ }
+
+ return ShapeUtil::MakeShape(element_type, output_dimensions);
+}
+
+} // namespace
+
+/* static */ StatusOr<Shape> ShapeInference::InferUnaryOpShape(
+ UnaryOperation operation, const Shape& arg) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of unary operation"));
+
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(arg));
+ switch (operation) {
+ case UNOP_FLOOR:
+ case UNOP_CEIL:
+ case UNOP_EXP:
+ case UNOP_LOG:
+ case UNOP_TANH:
+ if (!ShapeUtil::ElementIsFloating(arg)) {
+ return InvalidArgument(
+ "expected element type in shape to be floating for exp/log/tanh "
+ "operation; got %s",
+ PrimitiveType_Name(arg.element_type()).c_str());
+ }
+ return arg;
+ case UNOP_ABS:
+ case UNOP_SIGN:
+ case UNOP_NEGATE:
+ case UNOP_SORT:
+ return arg;
+
+ case UNOP_LOGICAL_NOT:
+ if (arg.element_type() != PRED) {
+ return InvalidArgument(
+ "expected pred element type in argument to logical-not operation; "
+ "got %s",
+ PrimitiveType_Name(arg.element_type()).c_str());
+ }
+ return arg;
+ default:
+ return InvalidArgument("unknown operation %s",
+ UnaryOperation_Name(operation).c_str());
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferConcatOpShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const int64 dimension) {
+ if (arg_shapes.size() == 0) {
+ return InvalidArgument("Concatenate expects at least one argument");
+ }
+ if (dimension < 0 || dimension >= ShapeUtil::Rank(*arg_shapes[0])) {
+ return InvalidArgument("dimension to concatenate along out of bounds: %lld",
+ dimension);
+ }
+ const Shape* arg_shape = nullptr;
+ for (const Shape* shape : arg_shapes) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
+ if (!arg_shape) {
+ arg_shape = shape;
+ continue;
+ }
+ if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
+ return InvalidArgument(
+ "cannot concatenate arrays with different ranks: %lld vs %lld",
+ ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
+ }
+ if (arg_shape->element_type() != shape->element_type()) {
+ return InvalidArgument(
+ "cannot concatenate arrays with different element types: %s vs %s",
+ PrimitiveType_Name(arg_shape->element_type()).c_str(),
+ PrimitiveType_Name(shape->element_type()).c_str());
+ }
+ for (int64 dimension_number = 0;
+ dimension_number < ShapeUtil::Rank(*arg_shape); ++dimension_number) {
+ if (arg_shape->dimensions(dimension_number) !=
+ shape->dimensions(dimension_number)) {
+ if (dimension_number == dimension) {
+ continue; // It's okay to differ in the dimension we're
+ // concatenating.
+ }
+ return InvalidArgument(
+ "cannot concatenate arrays that differ in dimensions other than "
+ "the one being concatenated (the other array dimensions must be "
+ "the same): %s vs %s",
+ ShapeUtil::HumanString(*arg_shape).c_str(),
+ ShapeUtil::HumanString(*shape).c_str());
+ }
+ }
+ }
+
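+ // The result matches the operands except along the concatenation dimension,
+ // whose size is the sum of the operand sizes in that dimension; for example
+ // (sketch), f32[2,3] and f32[2,5] concatenated along dimension 1 give
+ // f32[2,8].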
+ std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
+ arg_shape->dimensions().end());
+ for (size_t i = 1; i < arg_shapes.size(); ++i) {
+ new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
+ }
+ return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
+ const Shape& operand_shape, PrimitiveType new_element_type) {
+ if (ShapeUtil::IsTuple(operand_shape) || new_element_type == TUPLE) {
+ // Note: we may want to support tuple conversions via this operation in the
+ // future, by recursing into the tuple elements to check all sub-conversions
+ // are valid. For now we just reject them, though.
+ return InvalidArgument(
+ "cannot convert from or to tuple type; requested conversion: %s => %s",
+ ShapeUtil::HumanString(operand_shape).c_str(),
+ PrimitiveType_Name(new_element_type).c_str());
+ }
+
+ return ShapeUtil::ChangeElementType(operand_shape, new_element_type);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferPadShape(
+ const Shape& operand_shape, const Shape& padding_value_shape,
+ const PaddingConfig& padding_config) {
+ if (ShapeUtil::IsTuple(operand_shape)) {
+ return InvalidArgument(
+ "pad operation does not support tuple-shape operands");
+ }
+ if (!ShapeUtil::IsScalar(padding_value_shape)) {
+ return InvalidArgument(
+ "pad operation does not support non-scalar padding values");
+ }
+ if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) {
+ return InvalidArgument(
+ "the rank of the operand and the padding configuration do not match.");
+ }
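+ // Each output dimension is the operand dimension plus the low and high edge
+ // padding, plus interior_padding elements between each adjacent pair of
+ // operand elements. For example (sketch), size 3 with edge_padding_low = 1,
+ // edge_padding_high = 2, and interior_padding = 1 gives 3 + 1 + 2 + 2 = 8.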
+ std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
+ for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
+ dimensions[i] = operand_shape.dimensions(i) +
+ padding_config.dimensions(i).edge_padding_low() +
+ padding_config.dimensions(i).edge_padding_high() +
+ std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
+ padding_config.dimensions(i).interior_padding();
+ }
+ return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(const Shape& lhs,
+ const Shape& rhs) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
+
+ auto fail = [lhs, rhs](const string& addendum) -> Status {
+ string message = tensorflow::strings::Printf(
+ "cannot infer shape for dot operation: %s <dot> %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ if (!addendum.empty()) {
+ message += ": " + addendum;
+ }
+ return InvalidArgument("%s", message.c_str());
+ };
+
+ // Check if both element types are the same.
+ if (lhs.element_type() != rhs.element_type()) {
+ return fail("element types mismatch");
+ }
+
+ if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
+ ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) {
+ return fail("dot only supports rank 1 or 2");
+ }
+
+ // Determine the index of the contracted dimension for each input tensor:
+ // dimension -1 of lhs and dimension 0 of rhs are contracted.
+ int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1);
+ int64 rhs_contracted_dimension = 0;
+
+ // Check if the contracted dimension sizes are the same.
+ if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) &&
+ rhs_contracted_dimension < ShapeUtil::Rank(rhs)) &&
+ lhs.dimensions(lhs_contracted_dimension) !=
+ rhs.dimensions(rhs_contracted_dimension)) {
+ return fail("contracted dimensions mismatch");
+ }
+
+ // The rank of the result is the sum of the ranks of lhs and rhs, each
+ // reduced by one for its contracted dimension. When an input tensor is a
+ // scalar, its contribution to the rank of the result is 0. Generate the
+ // result dimensions in order: lhs dimensions followed by rhs dimensions,
+ // excluding the contracted dimensions.
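+ // For example (sketch), f32[2,3] <dot> f32[3,4] contracts dimension 1 of lhs
+ // with dimension 0 of rhs and yields f32[2,4]; f32[3] <dot> f32[3,4] yields
+ // f32[4].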
+ std::vector<int64> dimensions;
+ for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
+ if (i != lhs_contracted_dimension) {
+ dimensions.push_back(lhs.dimensions(i));
+ }
+ }
+ for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
+ if (i != rhs_contracted_dimension) {
+ dimensions.push_back(rhs.dimensions(i));
+ }
+ }
+ Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
+ VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
+ return result;
+}
+
+/* static */ StatusOr<Shape>
+ShapeInference::InferDegenerateDimensionBroadcastShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs) {
+ TF_RET_CHECK(ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs));
+
+ // The shapes have to be compatible. That is, if some dimension d has a
+ // different size in the two shapes, one of them has to be 1 (a "degenerate"
+ // dimension). In that case, the output shape has the non-1 dimension size
+ // from the lhs/rhs pair in every index.
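+ // For example (sketch), [2, 1] and [2, 3] yield [2, 3], while [2, 4] and
+ // [2, 3] are rejected as incompatible.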
+ std::vector<int64> output_dimensions(ShapeUtil::Rank(lhs));
+ for (int64 i = 0; i < ShapeUtil::Rank(lhs); ++i) {
+ if (lhs.dimensions(i) == rhs.dimensions(i)) {
+ output_dimensions[i] = lhs.dimensions(i);
+ } else if (lhs.dimensions(i) == 1) {
+ output_dimensions[i] = rhs.dimensions(i);
+ } else if (rhs.dimensions(i) == 1) {
+ output_dimensions[i] = lhs.dimensions(i);
+ } else {
+ return InvalidArgument("binary op %s with incompatible shapes: %s and %s",
+ BinaryOperation_Name(operation).c_str(),
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+ }
+ return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
+ BinaryOperation operation, const Shape& smaller_shape,
+ const Shape& larger_shape,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ if (broadcast_dimensions.empty() && !ShapeUtil::IsScalar(smaller_shape)) {
+ // Reject "magic" inference for binops on different shapes, requiring
+ // the user to provide an explicit broadcast dimension in this case.
+ // See b/25177275 for more details.
+ return InvalidArgument("automatic shape inference not supported: %s and %s",
+ ShapeUtil::HumanString(smaller_shape).c_str(),
+ ShapeUtil::HumanString(larger_shape).c_str());
+ } else if (broadcast_dimensions.size() != ShapeUtil::Rank(smaller_shape)) {
+ return InvalidArgument(
+ "size of broadcast_dimensions has to match lower-rank operand's "
+ "rank; "
+ " lower-rank operand's rank is %lld, size of broadcast_dimensions is "
+ "%zu",
+ ShapeUtil::Rank(smaller_shape), broadcast_dimensions.size());
+ }
+
+ // broadcast_dimensions is a sequence of dimensions; its length is equal to
+ // the rank of the lower-rank operand. The lower-rank operand's dimensions
+ // have to be compatible with the higher-rank operand's dimensions at indices
+ // specified by broadcast_dimensions. Here compatible means the dimension
+ // sizes are equal or in one of the shapes the dimension size is
+ // one. Examples:
+ //
+ // smaller_shape larger_shape broadcast_dimensions output_shape
+ // [] [2, 3] {} [2, 3]
+ // [3] [4, 3] {1} [4, 3]
+ // [2, 3] [2, 3, 4] {0, 1} [2, 3, 4]
+ // [2, 1] [2, 3, 4] {0, 2} [2, 3, 1]
+ // [2, 3] [2, 1, 4] {0, 1} [2, 3, 4]
+ //
+ // The column output_shape may not be the final shape of the XLA
+ // operation. After the "InDim" broadcasting implemented in this function
+ // expands the rank, degenerate-dimension broadcasting (implemented in
+ // InferDegenerateDimensionBroadcastShape) broadcasts dimensions of size one
+ // up to match the dimension size of the other operand. For example, consider
+ // the row in the table above with a smaller_shape of [2, 1]. The shape
+ // returned by this function is [2, 3, 1] (output_shape) however, the result
+ // shape of the XLA operation is [2, 3, 4] after degenerate-dimension
+ // broadcasting.
+ //
+ // Invalid broadcasts:
+ //
+ // smaller_shape=[3], larger_shape=[4, 3], broadcast_dimensions={0}
+ // Reason: Dimension zero** of larger_shape (size 4) is not compatible with
+ // dimension zero of smaller_shape (size 3). **Zero here comes from the value
+ // in broadcast_dimensions.
+ //
+ // smaller_shape=[2, 1], larger_shape=[2, 3, 4], broadcast_dimensions={1, 2}
+ // Reason: Dimension one of larger_shape (size 3) is not compatible with
+ // dimension zero of smaller_shape (size 2).
+
+ // The output shape is initially the larger_shape. Sizes of dimensions
+ // specified in broadcast_dimensions are then changed to match the
+ // corresponding dimension size in smaller_shape.
+ Shape output_shape(larger_shape);
+
+ for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
+ int64 dimension_to_match = broadcast_dimensions.at(i);
+ if (dimension_to_match < 0) {
+ return InvalidArgument(
+ "broadcast dimension number (%lld) cannot be negative",
+ dimension_to_match);
+ }
+ if (dimension_to_match >= larger_shape.dimensions_size()) {
+ return InvalidArgument(
+ "broadcast dimension number (%lld) too large; higher-rank "
+ "operand has rank %d",
+ dimension_to_match, larger_shape.dimensions_size());
+ }
+ int64 small_dimension_size = smaller_shape.dimensions(i);
+ int64 large_dimension_size = larger_shape.dimensions(dimension_to_match);
+ // Dimension sizes must be compatible: match or be degenerate (degenerate
+ // case is handled by degenerate dimension broadcasting which occurs after
+ // InDim broadcasting).
+ if (small_dimension_size != large_dimension_size &&
+ small_dimension_size != 1 && large_dimension_size != 1) {
+ return InvalidArgument(
+ "broadcast dimension %d mismatch: %lld != %lld; %s and %s", i,
+ small_dimension_size, large_dimension_size,
+ ShapeUtil::HumanString(smaller_shape).c_str(),
+ ShapeUtil::HumanString(larger_shape).c_str());
+ }
+ // Make sure the broadcast dimensions are listed in a strictly increasing
+ // order.
+ if (i > 0 && broadcast_dimensions.at(i - 1) >= dimension_to_match) {
+ return InvalidArgument(
+ "broadcast dimensions order is wrong: %lld comes after %lld",
+ dimension_to_match, broadcast_dimensions.at(i - 1));
+ }
+
+ output_shape.set_dimensions(dimension_to_match, small_dimension_size);
+ }
+
+ return output_shape;
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferElementwiseBinaryOpShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(lhs, "lhs of elementwise binary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
+
+ if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ return InvalidArgument("binary op with different element types: %s and %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+
+ if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) &&
+ !broadcast_dimensions.empty()) {
+ return InvalidArgument(
+ "broadcast dimensions field should not be set on binary "
+ "operations with operands of the same rank");
+ }
+
+ if (ShapeUtil::Compatible(lhs, rhs)) {
+ // If the shapes are the same other than layout, the output shape is the
+ // same (elementwise op).
+ return lhs;
+ }
+
+ if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
+ return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
+ } else {
+ // Ranks do not match, so perform InDim broadcasting using
+ // broadcast_dimensions (scalar broadcasting is a special case of this).
+ const Shape& larger_shape =
+ ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
+ const Shape& smaller_shape =
+ ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? rhs : lhs;
+
+ // After InDim broadcasting, perform degenerate dimensions broadcasting.
+ TF_ASSIGN_OR_RETURN(
+ Shape indim_broadcast_shape,
+ InferInDimBroadcastShape(operation, smaller_shape, larger_shape,
+ broadcast_dimensions));
+
+ return InferDegenerateDimensionBroadcastShape(
+ operation, indim_broadcast_shape, larger_shape);
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferBinaryOpShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
+ VLOG(2) << tensorflow::strings::Printf(
+ "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}",
+ BinaryOperation_Name(operation).c_str(),
+ ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(),
+ tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str());
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));
+
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of binary operation"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of binary operation"));
+ switch (operation) {
+ case BINOP_DOT:
+ return InferDotOpShape(lhs, rhs);
+ case BINOP_MAX:
+ case BINOP_MIN:
+ case BINOP_SUB:
+ case BINOP_ADD:
+ case BINOP_POW:
+ case BINOP_DIV:
+ case BINOP_REM:
+ case BINOP_MUL:
+ return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ broadcast_dimensions);
+
+ case BINOP_LOGICAL_AND:
+ case BINOP_LOGICAL_OR:
+ if (lhs.element_type() != PRED) {
+ return InvalidArgument(
+ "expected pred element type in argument to logical and/or "
+ "operation; got %s",
+ PrimitiveType_Name(lhs.element_type()).c_str());
+ }
+ return InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ broadcast_dimensions);
+
+ case BINOP_EQ:
+ case BINOP_GE:
+ case BINOP_GT:
+ case BINOP_LE:
+ case BINOP_LT:
+ case BINOP_NE: {
+ TF_ASSIGN_OR_RETURN(const Shape& shape,
+ InferElementwiseBinaryOpShape(operation, lhs, rhs,
+ broadcast_dimensions));
+ return ShapeUtil::ChangeElementType(shape, PRED);
+ }
+ case BINOP_INDEX:
+ if (ShapeUtil::Rank(lhs) > 0 && ShapeUtil::Rank(rhs) == 0) {
+ tensorflow::gtl::ArraySlice<int64> dimensions =
+ AsInt64Slice(lhs.dimensions());
+ dimensions.pop_front();
+ return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+ }
+ return Unimplemented("cannot infer shape for operation: %s <%s> %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ BinaryOperation_Name(operation).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ default:
+ return Unimplemented(
+ "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s",
+ BinaryOperation_Name(operation).c_str(),
+ lhs.ShortDebugString().c_str(), rhs.ShortDebugString().c_str());
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferTernaryOpShape(
+ TernaryOperation operation, const Shape& lhs, const Shape& rhs,
+ const Shape& ehs) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(ehs));
+ switch (operation) {
+ case TRIOP_CLAMP:
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
+ if (((ShapeUtil::Compatible(lhs, rhs) || ShapeUtil::Rank(lhs) == 0) &&
+ (ShapeUtil::Compatible(rhs, ehs) || ShapeUtil::Rank(ehs) == 0))) {
+ return rhs;
+ }
+ if (ShapeUtil::Rank(rhs) == 0) {
+ if (ShapeUtil::Compatible(lhs, ehs)) {
+ return lhs;
+ }
+ return ShapeUtil::Rank(ehs) == 0 ? lhs : ehs;
+ }
+ return Unimplemented("not yet implemented: %s, %s <clamp> %s",
+ lhs.ShortDebugString().c_str(),
+ ehs.ShortDebugString().c_str(),
+ rhs.ShortDebugString().c_str());
+ case TRIOP_SELECT:
+ return InferSelectShape(lhs, rhs, ehs);
+ case TRIOP_UPDATE:
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation"));
+ return lhs;
+ default:
+ return InvalidArgument("unknown operation %s",
+ TernaryOperation_Name(operation).c_str());
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferVariadicOpShape(
+ VariadicOperation operation, std::vector<const Shape*> operand_shapes) {
+ for (const Shape* shape : operand_shapes) {
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(*shape));
+ }
+ switch (operation) {
+ case VAROP_TUPLE: {
+ Shape result = ShapeUtil::MakeTupleShape({});
+ for (const Shape* shape : operand_shapes) {
+ ShapeUtil::AppendShapeToTuple(*shape, &result);
+ }
+ return result;
+ }
+ default:
+ return InvalidArgument("unknown operation %s",
+ VariadicOperation_Name(operation).c_str());
+ }
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferMapShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const ProgramShape& to_apply) {
+ if (arg_shapes.size() == 0) {
+ return InvalidArgument("Map expects at least one argument");
+ }
+
+ // All arguments must have the same shape.
+ const Shape* arg_shape = arg_shapes[0];
+ for (size_t i = 1; i < arg_shapes.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
+
+ if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
+ continue;
+ }
+ if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
+ !ShapeUtil::IsTuple(*arg_shape) &&
+ ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
+ if (ShapeUtil::IsScalar(*arg_shapes[i])) {
+ continue;
+ }
+ if (ShapeUtil::IsScalar(*arg_shape)) {
+ arg_shape = arg_shapes[i];
+ continue;
+ }
+ }
+
+ std::vector<string> pieces;
+ for (const Shape* shape : arg_shapes) {
+ pieces.push_back(ShapeUtil::HumanString(*shape));
+ }
+ return InvalidArgument(
+ "Map operation requires all operands to have the same shape; got: "
+ "%s",
+ tensorflow::str_util::Join(pieces, ", ").c_str());
+ }
+
+ // The applied function's arity equals the number of arguments.
+ if (arg_shapes.size() != to_apply.parameters_size()) {
+ return InvalidArgument(
+ "Map applied function arity must match number of arguments; got: "
+ "arity: %d, arguments: %zu",
+ to_apply.parameters_size(), arg_shapes.size());
+ }
+
+ // The parameters should all be scalars, and the output too.
+ const Shape& output_shape = to_apply.result();
+ if (!ShapeUtil::IsScalar(output_shape)) {
+ return InvalidArgument(
+ "mapped computation's result has to be a scalar; "
+ "got: %s",
+ ShapeUtil::HumanString(output_shape).c_str());
+ }
+
+ for (int i = 0; i < to_apply.parameters_size(); ++i) {
+ const Shape& parameter_shape = to_apply.parameters(i);
+
+ if (!ShapeUtil::IsScalar(parameter_shape)) {
+ return InvalidArgument(
+ "mapped computation's parameter has to be a scalar; "
+ "got parameter %d shape: %s",
+ i, ShapeUtil::HumanString(parameter_shape).c_str());
+ }
+
+ if (parameter_shape.element_type() != arg_shape->element_type()) {
+ return InvalidArgument(
+ "mapped computation's parameter type has to match argument element "
+ "type; got parameter %d shape: %s, argument shape: %s",
+ i, ShapeUtil::HumanString(parameter_shape).c_str(),
+ ShapeUtil::HumanString(*arg_shape).c_str());
+ }
+ }
+
+ return ShapeUtil::MakeShape(output_shape.element_type(),
+ AsInt64Slice(arg_shape->dimensions()));
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferConvolveShape(
+ const Shape& lhs, const Shape& rhs, const Window& window,
+ const ConvolutionDimensionNumbers& dnums) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
+
+ if (!ShapeUtil::SameElementType(lhs, rhs)) {
+ return InvalidArgument(
+ "Convolution with different element types: %s and %s",
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str());
+ }
+ if (dnums.spatial_dimensions_size() !=
+ dnums.kernel_spatial_dimensions_size()) {
+ return InvalidArgument(
+ "Both arguments to convolution must have same number of dimensions.\n"
+ "Window: %s",
+ window.DebugString().c_str());
+ }
+ int num_spatial_dims = dnums.spatial_dimensions_size();
+ if (num_spatial_dims < 1) {
+ return InvalidArgument(
+ "Convolution requires at least one spatial dimension.\n"
+ "Window: %s",
+ window.DebugString().c_str());
+ }
+
+ if (window.dimensions_size() != num_spatial_dims) {
+ return InvalidArgument(
+ "Window must have same number of dimensions as dimension numbers.\n"
+ "Window: %s\nDimension numbers: %s",
+ window.DebugString().c_str(), dnums.DebugString().c_str());
+ }
+
+ int num_dims = num_spatial_dims + 2;
+ if (ShapeUtil::Rank(lhs) != num_dims) {
+ return InvalidArgument(
+ "The LHS argument to a convolution should have rank %d.\n"
+ "lhs: %s",
+ num_dims, ShapeUtil::HumanString(lhs).c_str());
+ }
+ if (ShapeUtil::Rank(rhs) != num_dims) {
+ return InvalidArgument(
+ "The RHS argument to a convolution should have rank %d.\n"
+ "lhs: %s",
+ num_dims, ShapeUtil::HumanString(lhs).c_str());
+ }
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(lhs));
+ TF_DCHECK_OK(ShapeUtil::ValidateShape(rhs));
+
+ // Verifies that the input and window dimension numbers are each a
+ // permutation of {0, ..., num_dims - 1}.
+ std::vector<int64> input_dnums(num_dims);
+ input_dnums[0] = dnums.batch_dimension();
+ input_dnums[1] = dnums.feature_dimension();
+ std::copy(dnums.spatial_dimensions().begin(),
+ dnums.spatial_dimensions().end(), input_dnums.begin() + 2);
+ std::sort(input_dnums.begin(), input_dnums.end());
+
+ std::vector<int64> window_dnums(num_dims);
+ window_dnums[0] = dnums.kernel_input_feature_dimension();
+ window_dnums[1] = dnums.kernel_output_feature_dimension();
+ std::copy(dnums.kernel_spatial_dimensions().begin(),
+ dnums.kernel_spatial_dimensions().end(), window_dnums.begin() + 2);
+ std::sort(window_dnums.begin(), window_dnums.end());
+
+ std::vector<int64> expected_dnums(num_dims);
+ std::iota(expected_dnums.begin(), expected_dnums.end(), 0);
+
+ const auto in_range = [num_dims](int64 i) { return 0 <= i && i < num_dims; };
+ if (!std::all_of(input_dnums.begin(), input_dnums.end(), in_range) ||
+ !std::all_of(window_dnums.begin(), window_dnums.end(), in_range)) {
+ return InvalidArgument(
+ "A dimension number is out of range in convolution: %s",
+ dnums.DebugString().c_str());
+ }
+
+ if (input_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Input dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
+ if (window_dnums != expected_dnums) {
+ return InvalidArgument(
+ "Window dimensions of convolution must contain each dimension exactly "
+ "once: %s",
+ dnums.DebugString().c_str());
+ }
+
+ std::vector<int64> input_spatial_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i));
+ }
+ const int64 input_features = lhs.dimensions(dnums.feature_dimension());
+ const int64 input_batch = lhs.dimensions(dnums.batch_dimension());
+
+ std::vector<int64> kernel_spatial_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ kernel_spatial_dims[i] = rhs.dimensions(dnums.kernel_spatial_dimensions(i));
+ }
+ const int64 kernel_input_features =
+ rhs.dimensions(dnums.kernel_input_feature_dimension());
+ const int64 kernel_output_features =
+ rhs.dimensions(dnums.kernel_output_feature_dimension());
+
+ if (input_features != kernel_input_features) {
+ return InvalidArgument(
+ "Expected LHS feature dimension (value %lld) to match RHS "
+ "input feature dimension (value %lld); got <conv>(%s, %s)\n"
+ "Dimension numbers: {%s}",
+ input_features, kernel_input_features,
+ ShapeUtil::HumanString(lhs).c_str(),
+ ShapeUtil::HumanString(rhs).c_str(), dnums.DebugString().c_str());
+ }
+ std::vector<int64> window_dims(num_spatial_dims);
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ window_dims[i] = window.dimensions(i).size();
+ }
+ if (kernel_spatial_dims != window_dims) {
+ return InvalidArgument(
+ "Window dimensions do not match RHS shape:\n\t"
+ "RHS shape: %s\n\t"
+ "Window: {%s}\n\t"
+ "Dimension numbers: {%s}",
+ ShapeUtil::HumanString(rhs).c_str(), window.ShortDebugString().c_str(),
+ dnums.ShortDebugString().c_str());
+ }
+
+ Shape base_shape =
+ ShapeUtil::MakeShape(lhs.element_type(), input_spatial_dims);
+ TF_ASSIGN_OR_RETURN(
+ Shape window_output_shape,
+ InferWindowOutputShape(base_shape, window, lhs.element_type(),
+ /*allow_negative_padding=*/true));
+
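+ // Assemble the output shape: the batch dimension size comes from the LHS,
+ // the feature dimension size is the kernel's output-feature count, and each
+ // spatial dimension size comes from the window computation above.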
+ std::vector<int64> dimensions(num_dims);
+ dimensions[dnums.batch_dimension()] = input_batch;
+ dimensions[dnums.feature_dimension()] = kernel_output_features;
+ for (int i = 0; i < num_spatial_dims; ++i) {
+ dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i);
+ }
+
+ return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferCrossReplicaSumShape(
+ const Shape& operand) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand, "operand of cross replica sum"));
+ return operand;
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferReduceShape(
+ const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const ProgramShape& to_apply) {
+ // Check that the dimensions to reduce are in-bounds for the given shape.
+ for (int64 dimension : dimensions_to_reduce) {
+ if (dimension >= ShapeUtil::Rank(arg) || dimension < 0) {
+ return InvalidArgument(
+ "attempting to reduce out-of-bounds dimension %lld in shape %s",
+ dimension, ShapeUtil::HumanString(arg).c_str());
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ VerifyReducerShape(to_apply, init_value, arg.element_type()));
+
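+ // The result drops the reduced dimensions and takes its element type from
+ // the reducer's result; for example (sketch), reducing f32[2,3,4] over
+ // dimension 1 with an f32 reducer yields f32[2,4].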
+ std::set<int64> dimensions_to_reduce_set(dimensions_to_reduce.begin(),
+ dimensions_to_reduce.end());
+ std::vector<int64> new_dimensions;
+ for (int i = 0; i < ShapeUtil::Rank(arg); ++i) {
+ if (dimensions_to_reduce_set.find(i) == dimensions_to_reduce_set.end()) {
+ new_dimensions.push_back(arg.dimensions(i));
+ }
+ }
+
+ return ShapeUtil::MakeShape(to_apply.result().element_type(), new_dimensions);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
+ const Shape& operand_shape, const Shape& init_value_shape,
+ const Window& window, const ProgramShape& to_apply_shape) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of reduce-window"));
+ TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_value_shape,
+ operand_shape.element_type()));
+ return InferWindowOutputShape(operand_shape, window,
+ init_value_shape.element_type(),
+ /*allow_negative_padding=*/false);
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferSelectAndScatterShape(
+ const Shape& operand_shape, const ProgramShape& select_shape,
+ const Window& window, const Shape& source_shape,
+ const Shape& init_value_shape, const ProgramShape& scatter_shape) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of select-and-scatter"));
+
+ // Check if the select function has a proper shape of (T,T) -> PRED.
+ if (select_shape.parameters_size() != 2) {
+ return InvalidArgument(
+ "select function must take 2 parameters, but "
+ "takes %d parameter(s).",
+ select_shape.parameters_size());
+ }
+ const Shape& select_result_shape = select_shape.result();
+ if (!ShapeUtil::Compatible(select_result_shape,
+ ShapeUtil::MakeShape(PRED, {}))) {
+ return Unimplemented("select function must have rank-0 PRED result.");
+ }
+ const Shape& operand_element_shape =
+ ShapeUtil::MakeShape(operand_shape.element_type(), {});
+ if (!ShapeUtil::Compatible(operand_element_shape,
+ select_shape.parameters(0))) {
+ return InvalidArgument(
+ "select function's first parameter shape currently must "
+ "match the operand element shape. Got %s vs %s",
+ ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
+ ShapeUtil::HumanString(operand_element_shape).c_str());
+ }
+ if (!ShapeUtil::Compatible(operand_element_shape,
+ select_shape.parameters(1))) {
+ return InvalidArgument(
+ "select function's second parameter shape currently must "
+ "match the operand element shape. Got %s vs %s",
+ ShapeUtil::HumanString(select_shape.parameters(1)).c_str(),
+ ShapeUtil::HumanString(operand_element_shape).c_str());
+ }
+
+ // Check if the scatter function has a proper shape as a reduction.
+ TF_RETURN_IF_ERROR(VerifyReducerShape(scatter_shape, init_value_shape,
+ source_shape.element_type()));
+
+ // Check if the result shape of window operation matches the source shape.
+ TF_ASSIGN_OR_RETURN(const Shape& window_result_shape,
+ InferWindowOutputShape(operand_shape, window,
+ operand_shape.element_type(),
+ /*allow_negative_padding=*/false));
+ if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
+ return InvalidArgument(
+ "source shape does not match the shape of window-reduced operand: "
+ "source(%s), window-reduced operand(%s)",
+ ShapeUtil::HumanString(source_shape).c_str(),
+ ShapeUtil::HumanString(window_result_shape).c_str());
+ }
+ return operand_shape;
+}
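+// Worked example (see the SelectAndScatter tests): an f32[8,16] operand with
+// 2x2 windows of stride 2 window-reduces to f32[4,8], so the source must be
+// f32[4,8]; the inferred result is the operand shape, f32[8,16].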
+
+/* static */ StatusOr<Shape> ShapeInference::InferSliceShape(
+ const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
+ tensorflow::gtl::ArraySlice<int64> limits) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice"));
+ VLOG(2) << tensorflow::strings::Printf(
+ "slicing shape %s starts={%s} limits={%s}",
+ ShapeUtil::HumanString(arg).c_str(),
+ tensorflow::str_util::Join(starts, ", ").c_str(),
+ tensorflow::str_util::Join(limits, ", ").c_str());
+
+ if (starts.size() != limits.size()) {
+ return InvalidArgument("slice start and limit sizes differ: %zu vs %zu",
+ starts.size(), limits.size());
+ }
+
+ if (starts.size() != ShapeUtil::Rank(arg)) {
+ return InvalidArgument(
+ "slice index count does not match argument rank: %zu vs %lld",
+ starts.size(), ShapeUtil::Rank(arg));
+ }
+
+ std::vector<int64> sizes;
+ for (int64 dimension = 0; dimension < starts.size(); ++dimension) {
+ int64 start_index = starts[dimension];
+ int64 limit_index = limits[dimension];
+ if (start_index < 0) {
+ return InvalidArgument("negative start index to slice: %lld",
+ start_index);
+ }
+ if (limit_index < 0) {
+ return InvalidArgument("negative limit index to slice: %lld",
+ limit_index);
+ }
+ if (limit_index > arg.dimensions(dimension)) {
+ return InvalidArgument(
+ "limit index (%lld) must be less than or equal to dimension "
+ "size (%lld)",
+ limit_index, arg.dimensions(dimension));
+ }
+ if (start_index > limit_index) {
+ return InvalidArgument(
+ "limit index (%lld) must be greater or equal to "
+ "start index (%lld) in slice",
+ limit_index, start_index);
+ }
+ VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension,
+ start_index);
+ VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension,
+ limit_index);
+
+ sizes.push_back(limits[dimension] - starts[dimension]);
+ }
+
+ return ShapeUtil::MakeShape(arg.element_type(), sizes);
+}
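+// Worked example: slicing f32[128,64] with starts={32,0} and limits={64,64}
+// infers f32[32,64]; each output dimension is limit - start.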
+
+/* static */ StatusOr<Shape> ShapeInference::InferDynamicSliceShape(
+ const Shape& operand_shape, const Shape& start_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic slice"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(start_indices_shape,
+ "start indices of dynamic slice"));
+
+ VLOG(2) << tensorflow::strings::Printf(
+ "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}",
+ ShapeUtil::HumanString(operand_shape).c_str(),
+ ShapeUtil::HumanString(start_indices_shape).c_str(),
+ tensorflow::str_util::Join(slice_sizes, ", ").c_str());
+
+ if (ShapeUtil::Rank(start_indices_shape) != 1) {
+ return InvalidArgument(
+ "dynamic slice start indices of rank %lld must be rank1.",
+ ShapeUtil::Rank(start_indices_shape));
+ }
+
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
+ return InvalidArgument(
+ "dynamic slice start indices must be of integral type.");
+ }
+
+ const int64 start_num_dims = start_indices_shape.dimensions(0);
+ if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
+ return InvalidArgument(
+ "dynamic slice start number of dimensions %lld must match rank %lld of "
+ "slice input",
+ start_num_dims, ShapeUtil::Rank(operand_shape));
+ }
+
+ if (slice_sizes.size() != ShapeUtil::Rank(operand_shape)) {
+ return InvalidArgument(
+ "dynamic slice index count does not match argument rank: %zu vs %lld",
+ slice_sizes.size(), ShapeUtil::Rank(operand_shape));
+ }
+
+ for (int64 dim = 0; dim < slice_sizes.size(); ++dim) {
+ const int64 input_dim_size = operand_shape.dimensions(dim);
+ const int64 slice_dim_size = slice_sizes[dim];
+ if (slice_dim_size <= 0) {
+ return InvalidArgument("negative size index to dynamic slice: %lld",
+ slice_dim_size);
+ }
+ if (slice_dim_size > input_dim_size) {
+ return InvalidArgument(
+ "slice dim size %lld greater than dynamic slice dimension: %lld",
+ slice_dim_size, input_dim_size);
+ }
+ VLOG(2) << tensorflow::strings::Printf("slice_sizes[%lld] = %lld", dim,
+ slice_dim_size);
+ }
+
+ return ShapeUtil::MakeShape(operand_shape.element_type(), slice_sizes);
+}
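+// Worked example: dynamically slicing f32[128,64] with s32[2] start indices
+// and slice_sizes={32,64} infers f32[32,64]; only the static slice sizes, not
+// the runtime start indices, determine the result shape.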
+
+/* static */ StatusOr<Shape> ShapeInference::InferDynamicUpdateSliceShape(
+ const Shape& operand_shape, const Shape& update_shape,
+ const Shape& start_indices_shape) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of dynamic update slice"));
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(update_shape, "update of dynamic update slice"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ start_indices_shape, "start indices of dynamic update slice"));
+
+ VLOG(2) << tensorflow::strings::Printf(
+ "updating slice of shape %s at dynamic start_indices %s with update "
+ "shape %s",
+ ShapeUtil::HumanString(operand_shape).c_str(),
+ ShapeUtil::HumanString(start_indices_shape).c_str(),
+ ShapeUtil::HumanString(update_shape).c_str());
+
+ if (ShapeUtil::Rank(start_indices_shape) != 1) {
+ return InvalidArgument(
+ "dynamic update slice start indices of rank %lld must be rank1.",
+ ShapeUtil::Rank(start_indices_shape));
+ }
+
+ if (!ShapeUtil::ElementIsIntegral(start_indices_shape)) {
+ return InvalidArgument(
+ "dynamic update slice start indices must be of integral type.");
+ }
+
+ const int64 start_num_dims = start_indices_shape.dimensions(0);
+ if (ShapeUtil::Rank(operand_shape) != start_num_dims) {
+ return InvalidArgument(
+ "dynamic update slice start number of dimensions %lld must match "
+ "rank %lld of slice input",
+ start_num_dims, ShapeUtil::Rank(operand_shape));
+ }
+
+ if (ShapeUtil::Rank(update_shape) != ShapeUtil::Rank(operand_shape)) {
+ return InvalidArgument(
+ "dynamic update slice update rank does not match argument rank: "
+ "%lld vs %lld",
+ ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
+ }
+
+ if (operand_shape.element_type() != update_shape.element_type()) {
+ return InvalidArgument(
+ "dynamic update slice update element type does not match argument. "
+ "operand.element_type: %s vs update.element_type: %s",
+ PrimitiveType_Name(operand_shape.element_type()).c_str(),
+ PrimitiveType_Name(update_shape.element_type()).c_str());
+ }
+
+ for (int64 dim = 0; dim < ShapeUtil::Rank(operand_shape); ++dim) {
+ const int64 input_dim_size = operand_shape.dimensions(dim);
+ const int64 update_dim_size = update_shape.dimensions(dim);
+ if (update_dim_size <= 0) {
+ return InvalidArgument(
+ "size index %lld to dynamic update slice must be > 0",
+ update_dim_size);
+ }
+ if (update_dim_size > input_dim_size) {
+ return InvalidArgument(
+ "update dim size %lld greater than dynamic slice dimension: %lld",
+ update_dim_size, input_dim_size);
+ }
+ VLOG(2) << tensorflow::strings::Printf("update_sizes[%lld] = %lld", dim,
+ update_dim_size);
+ }
+
+ return operand_shape;
+}
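+// Worked example: updating f32[128,64] with an f32[32,64] update and s32[2]
+// start indices infers the operand shape f32[128,64]; the update must have
+// the same rank and element type and fit inside the operand.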
+
+/*static */ StatusOr<Shape> ShapeInference::InferReverseShape(
+ const Shape& operand_shape, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(operand_shape, "operand of reverse"));
+ if (!AllUnique(dimensions)) {
+ return InvalidArgument("a dimension number is duplicated in reverse");
+ }
+ for (int64 dimension : dimensions) {
+ if (dimension >= ShapeUtil::Rank(operand_shape) || dimension < 0) {
+ return InvalidArgument(
+ "one of the reverse dimensions (%lld) is out-of-bounds in shape %s",
+ dimension, ShapeUtil::HumanString(operand_shape).c_str());
+ }
+ }
+ return operand_shape;
+}
+
+/* static */ StatusOr<Shape> ShapeInference::InferGetTupleElementShape(
+ const Shape& arg, int64 index) {
+ if (!ShapeUtil::IsTuple(arg)) {
+ return InvalidArgument(
+ "cannot infer shape: attempting to index into non-tuple: %s",
+ ShapeUtil::HumanString(arg).c_str());
+ }
+
+ if (index >= arg.tuple_shapes_size()) {
+ return InvalidArgument(
+ "cannot infer shape: attempt to index out of tuple bounds: %lld "
+ ">= %d in shape %s",
+ index, arg.tuple_shapes_size(), ShapeUtil::HumanString(arg).c_str());
+ }
+
+ return arg.tuple_shapes(index);
+}
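+// Worked example: indexing element 1 of the tuple shape (f32[], s32[]) infers
+// s32[].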
+
+/* static */ StatusOr<Shape> ShapeInference::InferWhileShape(
+ const ProgramShape& condition, const ProgramShape& body,
+ const Shape& init) {
+ // Check the number of parameters for given computations.
+ if (condition.parameters_size() != 1) {
+ return InvalidArgument("condition must take 1 arguments; got %d",
+ condition.parameters_size());
+ }
+ if (body.parameters_size() != 1) {
+ return InvalidArgument("body must take 1 arguments; got %d",
+ body.parameters_size());
+ }
+
+ string shape_string = tensorflow::strings::Printf(
+ "condition: %s; body: %s; init: %s", condition.ShortDebugString().c_str(),
+ body.ShortDebugString().c_str(), init.ShortDebugString().c_str());
+
+ // Check the shapes of computation parameters and return types.
+ if (!ShapeUtil::ShapeIs(condition.result(), PRED, {})) {
+ return InvalidArgument("condition must return a boolean; got %s",
+ shape_string.c_str());
+ }
+ if (!ShapeUtil::Compatible(body.result(), condition.parameters(0)) ||
+ !ShapeUtil::Compatible(body.result(), body.parameters(0)) ||
+ !ShapeUtil::Compatible(body.result(), init)) {
+ return InvalidArgument(
+ "the parameter of condition and body, the result of the body, and init "
+ "must all have the same shape; got %s",
+ shape_string.c_str());
+ }
+
+ return init;
+}
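+// Worked example: with condition (f32[]) -> pred[], body (f32[]) -> f32[] and
+// init f32[], the inferred while shape is f32[] (the init shape).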
+
+/* static */ StatusOr<Shape> ShapeInference::InferBroadcastShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast"));
+ for (int64 size : broadcast_sizes) {
+ if (size < 0) {
+ return InvalidArgument("Broadcast with negative dimension size %lld.",
+ size);
+ }
+ }
+
+ std::vector<int64> dimensions(operand.dimensions_size() +
+ broadcast_sizes.size());
+ std::copy(broadcast_sizes.begin(), broadcast_sizes.end(), dimensions.begin());
+ std::copy(operand.dimensions().begin(), operand.dimensions().end(),
+ dimensions.begin() + broadcast_sizes.size());
+ return ShapeUtil::MakeShape(operand.element_type(), dimensions);
+}
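+// Worked example: broadcasting f32[2,3] with broadcast_sizes={4} infers
+// f32[4,2,3]; the new broadcast dimensions are prepended ahead of the
+// operand's dimensions.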
+
+/* static */ StatusOr<Shape> ShapeInference::InferReshapeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "reshape"));
+
+ Shape inferred_shape =
+ ShapeUtil::MakeShape(operand.element_type(), new_sizes);
+
+ if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) {
+ return InvalidArgument(
+ "reshape operation has mismatched element counts: from=%lld to=%lld",
+ ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape));
+ }
+
+ std::vector<int64> indices(ShapeUtil::Rank(operand));
+ std::iota(indices.begin(), indices.end(), 0);
+ if (dimensions.size() != ShapeUtil::Rank(operand) ||
+ !std::is_permutation(dimensions.begin(), dimensions.end(),
+ indices.begin())) {
+ return InvalidArgument(
+ "Reshape dimensions not a permutation of the operand dimensions.");
+ }
+
+ return inferred_shape;
+}
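+// Worked example: reshaping f32[2,6] with dimensions={0,1} and new_sizes={3,4}
+// infers f32[3,4]; the element counts must match (12 == 12) and dimensions
+// must be a permutation of the operand's dimension numbers.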
+
+/* static */ StatusOr<Shape> ShapeInference::InferTransposeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "transpose"));
+
+ std::vector<int64> indices(ShapeUtil::Rank(operand));
+ std::iota(indices.begin(), indices.end(), 0);
+ if (dimensions.size() != ShapeUtil::Rank(operand) ||
+ !std::is_permutation(dimensions.begin(), dimensions.end(),
+ indices.begin())) {
+ return InvalidArgument(
+ "Transpose dimensions not a permutation of the operand dimensions.");
+ }
+
+ // Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
+ // we need output[i]=input[dimensions[i]] which is
+ // Permute(Inverse(dimensions),input).
+ return ShapeUtil::MakeShape(operand.element_type(),
+ Permute(InversePermutation(dimensions),
+ AsInt64Slice(operand.dimensions())));
+}
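+// Worked example: transposing f32[2,3,4] with dimensions={1,2,0} gives output
+// dimension i the size of input dimension dimensions[i], so the inferred
+// shape is f32[3,4,2].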
+
+/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
+ const Shape& pred, const Shape& on_true, const Shape& on_false) {
+ if (!ShapeUtil::Compatible(on_true, on_false)) {
+ return InvalidArgument(
+ "operands to select must be the same shape; got %s and %s",
+ ShapeUtil::HumanString(on_true).c_str(),
+ ShapeUtil::HumanString(on_false).c_str());
+ }
+ if (pred.element_type() != PRED) {
+ return InvalidArgument(
+ "select's pred operand must have PRED element type; got %s",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+ if (ShapeUtil::SameDimensions(pred, on_true) || ShapeUtil::Rank(pred) == 0) {
+ // By this stage we know that pred's element type is PRED. Therefore, this
+ // check restricts pred to be a PRED scalar, or a PRED array with the same
+ // dimensions as on_true and on_false.
+ return on_true;
+ } else {
+ return Unimplemented(
+ "select operation with non-scalar predicate with dimensionality "
+ " different from the other operands: %s",
+ ShapeUtil::HumanString(pred).c_str());
+ }
+}
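+// Worked example: select(pred[64,48], f32[64,48], f32[64,48]) infers
+// f32[64,48]; a scalar pred[] predicate is also accepted, while any other
+// predicate shape is currently unimplemented.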
+
+/* static */ StatusOr<Shape> ShapeInference::InferCallShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const ProgramShape& to_apply) {
+ // The applied function's arity equals the number of arguments.
+ if (arg_shapes.size() != to_apply.parameters_size()) {
+ return InvalidArgument(
+ "Call applied function arity must match number of arguments; got: "
+ "arity: %d, arguments: %zu",
+ to_apply.parameters_size(), arg_shapes.size());
+ }
+
+ // All arguments must be compatible with the program shape.
+ for (int i = 0; i < arg_shapes.size(); ++i) {
+ const Shape& arg_shape = *arg_shapes[i];
+ const Shape& param_shape = to_apply.parameters(i);
+ if (!ShapeUtil::Compatible(arg_shape, param_shape)) {
+ return InvalidArgument(
+ "Call parameter must match argument; got parameter %d shape: %s, "
+ "argument shape: %s",
+ i, ShapeUtil::HumanString(param_shape).c_str(),
+ ShapeUtil::HumanString(arg_shape).c_str());
+ }
+ }
+
+ return to_apply.result();
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
new file mode 100644
index 0000000000..ced2f4d001
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -0,0 +1,219 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Shape inference is used by the XLA service as the user builds up
+// computation requests.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// For a given operation and input shapes, infers what the resulting shape is
+// for the operation. With this functionality, the user does not need to
+// specify the expected result type for computations that are built up via the
+// API -- the shape that results from an operation is inferred.
+class ShapeInference {
+ public:
+ // Infers the shape produced by applying the given unary operation to the
+ // given input shape.
+ static StatusOr<Shape> InferUnaryOpShape(UnaryOperation operation,
+ const Shape& arg);
+
+ // Infers the shape produced by applying the given binary operation to the
+ // given input shapes.
+ static StatusOr<Shape> InferBinaryOpShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Infers the shape produced by applying the given ternary operation to the
+ // given input shapes.
+ static StatusOr<Shape> InferTernaryOpShape(TernaryOperation operation,
+ const Shape& lhs, const Shape& rhs,
+ const Shape& ehs);
+
+ // Infers the shape produced by applying the given variadic operation to the
+ // given input operand shapes.
+ static StatusOr<Shape> InferVariadicOpShape(
+ VariadicOperation operation, std::vector<const Shape*> operand_shapes);
+
+ // Infers the shape produced by applying the given mapping computation shape
+ // to the given operand shapes.
+ static StatusOr<Shape> InferMapShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const ProgramShape& to_apply);
+
+ // Infers the shape produced by applying the given convolutional
+ // filter (rhs) to lhs in the way specified by the fields on window.
+ static StatusOr<Shape> InferConvolveShape(
+ const Shape& lhs, const Shape& rhs, const Window& window,
+ const ConvolutionDimensionNumbers& dimension_numbers);
+
+  // Infers the shape produced by a cross replica sum with the given operand
+  // shape.
+ static StatusOr<Shape> InferCrossReplicaSumShape(const Shape& operand);
+
+ // Infers the shape produced by applying the given reduction computation
+ // shape to the given input operand shape.
+ static StatusOr<Shape> InferReduceShape(
+ const Shape& arg, const Shape& init_value,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
+ const ProgramShape& to_apply);
+
+ // Infers the shape produced by applying the given computation to the operand
+ // shape with the given window and stride dimensions.
+ static StatusOr<Shape> InferReduceWindowShape(
+ const Shape& operand_shape, const Shape& init_value, const Window& window,
+ const ProgramShape& to_apply_shape);
+
+ // Infers the shape produced by scattering the given source shape to the
+ // selected indices of each window on the operand shape.
+ static StatusOr<Shape> InferSelectAndScatterShape(
+ const Shape& operand_shape, const ProgramShape& select_shape,
+ const Window& window, const Shape& source_shape,
+ const Shape& init_value_shape, const ProgramShape& scatter_shape);
+
+ // Infers the shape produced by a reverse operation that reverses the order
+ // of the elements in the given dimensions.
+ static StatusOr<Shape> InferReverseShape(
+ const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Infers the shape produced by a slice operation spanning from the starts to
+ // the limits in the original shape's dimensions.
+ //
+ // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
+ static StatusOr<Shape> InferSliceShape(
+ const Shape& arg, tensorflow::gtl::ArraySlice<int64> starts,
+ tensorflow::gtl::ArraySlice<int64> limits);
+
+ // Infers the shape produced by a dynamic slice operation of size specified
+ // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
+ static StatusOr<Shape> InferDynamicSliceShape(
+ const Shape& operand_shape, const Shape& start_indices_shape,
+ tensorflow::gtl::ArraySlice<int64> slice_sizes);
+
+ // Infers the shape produced by a dynamic update slice operation based
+ // on the shape of operand and update.
+ static StatusOr<Shape> InferDynamicUpdateSliceShape(
+ const Shape& operand_shape, const Shape& update_shape,
+ const Shape& start_indices_shape);
+
+ // Infers the shape produced by doing a compile-time-constant indexing into
+ // the given input shape. This is essential for operations on tuples, because
+ // it is impossible to infer the type that comes out of the tuple indexing if
+ // it is not a compile time constant.
+ static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
+ int64 index);
+
+ // Infers the shape produced from a while node. condition and body are the
+ // shapes of computations for the condition and the body of a while node, and
+ // init is the shape of data initially passed in to the body as an argument.
+ // The shapes must match; condition: T -> PRED, body: T -> T, init: T
+ static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
+ const ProgramShape& body,
+ const Shape& init);
+
+ // Infers the shape produced by a broadcast operation.
+ static StatusOr<Shape> InferBroadcastShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes);
+
+ // Infers the shape produced by a reshape operation from the element type of
+ // its operand and the new dimension sizes specified.
+ static StatusOr<Shape> InferReshapeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> new_sizes);
+
+ // Infers the shape produced by a transpose operation from the element type of
+ // its operand and its dimensions field.
+ static StatusOr<Shape> InferTransposeShape(
+ const Shape& operand, tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Helper that infers the shape produced by performing a concatenate operation
+ // with the given operand shapes.
+ static StatusOr<Shape> InferConcatOpShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes, int64 dimension);
+
+ // Helper that validates the given operand shape can be converted to the
+ // target output_shape via a convert instruction -- the requirement is that
+ // the shape is identical except for the element type.
+ static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
+ PrimitiveType new_element_type);
+
+ // Helper that infers the shape produced by a pad operation based on the
+ // padding configuration.
+ static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
+ const Shape& padding_value_shape,
+ const PaddingConfig& padding_config);
+
+ // Helper that validates the given arg_shapes are compatible with the shape of
+ // the to_apply parameters, and returns the to_apply result shape.
+ static StatusOr<Shape> InferCallShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
+ const ProgramShape& to_apply);
+
+ private:
+ // Helper that infers the shape produced by performing a dot operation with
+ // the given LHS and RHS shapes.
+ static StatusOr<Shape> InferDotOpShape(const Shape& lhs, const Shape& rhs);
+
+ // Helper that infers the shape produced by performing an element-wise binary
+ // operation with the given LHS and RHS shapes.
+ // Note: By "element-wise" we mean operations that look at a single element in
+ // the LHS and a single element in the RHS to produce a single output element,
+ // even in the presence of broadcasting of one of the operands over the other.
+ static StatusOr<Shape> InferElementwiseBinaryOpShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ // Helper for inferring the shape of Select ops.
+ static StatusOr<Shape> InferSelectShape(const Shape& pred,
+ const Shape& on_true,
+ const Shape& on_false);
+
+ // Helper for inferring shapes of binary operations which use degenerate
+ // dimension broadcasting (a dimension of size 1 in one operand is broadcast
+ // up to match the size of the dimension in the other operand).
+ static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
+ BinaryOperation operation, const Shape& lhs, const Shape& rhs);
+
+ // Helper for inferring shapes of binary operations using "InDim"
+ // broadcasting. This is the broadcasting used in the *InDim binary operations
+ // (for example ComputationBuilder::AddInDim). smaller_shape must be a
+ // lower-rank shape than larger_shape. Returns the shape that the
+ // smaller_shape is broadcast to.
+ static StatusOr<Shape> InferInDimBroadcastShape(
+ BinaryOperation operation, const Shape& smaller_shape,
+ const Shape& larger_shape,
+ tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
new file mode 100644
index 0000000000..10fd4e53c5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -0,0 +1,1133 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace {
+
+class ShapeInferenceTest : public ::testing::Test {
+ protected:
+ // Some handy scalar shapes.
+ const Shape s32_ = ShapeUtil::MakeShape(S32, {});
+ const Shape f32_ = ShapeUtil::MakeShape(F32, {});
+ const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
+
+ // Some handy vector and matrix shapes of F32 type.
+ // Suffix: vector_length_, matrix_rows_cols_
+ const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
+ const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
+ const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
+ const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
+ const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
+
+ // Some handy S32 arrays.
+ const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
+};
+
+// Subclass for testing InferReduceShape.
+class ReduceShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+ // Helper that runs reduce shape inference with the input 'arg' and given
+ // dimensions to reduce, and checks the inferred shape is as expected. The
+ // element type here is hard-coded to F32.
+ void ExpectInferredReduceShape(
+ const Shape& expected_inferred_shape, const Shape& arg,
+ tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+ auto inferred_status = ShapeInference::InferReduceShape(
+ arg, f32_, dimensions_to_reduce, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
+ inferred_status.ValueOrDie()));
+ }
+};
+
+// Subclass for testing InferSelectAndScatterShape.
+class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+ SelectAndScatterShapeInferenceTest() {
+ operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
+ source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(2);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(1);
+ dim.set_base_dilation(1);
+ *window_.add_dimensions() = dim;
+ *window_.add_dimensions() = dim;
+ init_value_shape_ = ShapeUtil::MakeShape(F32, {});
+ select_program_shape_ = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
+ scatter_program_shape_ = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
+ }
+
+ Shape operand_shape_;
+ Shape source_shape_;
+ Window window_;
+ Shape init_value_shape_;
+ ProgramShape select_program_shape_;
+ ProgramShape scatter_program_shape_;
+};
+
+TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status = ShapeInference::InferUnaryOpShape(
+ UnaryOperation::UNOP_NEGATE, matrix_shape);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
+ Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, pred_, tuple, tuple);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(tuple, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
+ auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
+ auto inferred_status = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, predarray, matrix_64_48_, matrix_64_48_);
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, SelectBadShapes) {
+ auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, pred_, matrix_64_48_, matrix_32_64_);
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(
+ inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("operands to select must be the same shape"));
+
+ auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, s32_, matrix_64_48_, matrix_64_48_);
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("pred operand must have PRED"));
+
+ auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeShape(PRED, {64}),
+ matrix_64_48_, matrix_64_48_);
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(
+ inferred_status_error3.status().error_message(),
+ testing::ContainsRegex("with non-scalar predicate with dimensionality"));
+
+ // Tuples have a TUPLE element type and cannot be the pred of a select.
+ auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
+ TernaryOperation::TRIOP_SELECT, ShapeUtil::MakeTupleShape({pred_, pred_}),
+ ShapeUtil::MakeTupleShape({f32_, f32_}),
+ ShapeUtil::MakeTupleShape({f32_, f32_}));
+ ASSERT_FALSE(inferred_status_error4.ok());
+ ASSERT_MATCH(
+ inferred_status_error4.status().error_message(),
+ testing::ContainsRegex("pred operand must have PRED element type"));
+}
+
+TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
+ StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
+ VariadicOperation::VAROP_TUPLE, {&s32_, &f32_});
+ ASSERT_IS_OK(result.status());
+ ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
+ ShapeUtil::MakeTupleShape({s32_, f32_})));
+}
+
+TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(2);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(1);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+ Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
+ Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
+ Shape float_scalar = ShapeUtil::MakeShape(F32, {});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
+ auto inferred_status = ShapeInference::InferReduceWindowShape(
+ matrix_shape, init_value_shape, window, to_apply);
+
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
+ auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_, window_, source_shape_,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_IS_OK(inferred_status_ok.status());
+ Shape inferred = inferred_status_ok.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
+ Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
+ auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_, window_, source_shape_fail,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_FALSE(inferred_status_fail.ok());
+ ASSERT_MATCH(inferred_status_fail.status().error_message(),
+ testing::ContainsRegex("source shape does not match"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
+ ProgramShape select_program_shape_fail =
+ ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
+ auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_fail, window_, source_shape_,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_FALSE(inferred_status_fail.ok());
+ ASSERT_MATCH(
+ inferred_status_fail.status().error_message(),
+ testing::ContainsRegex("select function must take 2 parameters"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
+ ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
+ auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_fail, window_, source_shape_,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_FALSE(inferred_status_fail.ok());
+ ASSERT_MATCH(inferred_status_fail.status().error_message(),
+ testing::ContainsRegex("select function must have rank-0 PRED"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
+ ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
+ auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_fail, window_, source_shape_,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_FALSE(inferred_status_fail.ok());
+ ASSERT_MATCH(inferred_status_fail.status().error_message(),
+ testing::ContainsRegex("select function's first parameter"));
+}
+
+TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
+ ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
+ {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
+ auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
+ operand_shape_, select_program_shape_fail, window_, source_shape_,
+ init_value_shape_, scatter_program_shape_);
+ ASSERT_FALSE(inferred_status_fail.ok());
+ ASSERT_MATCH(inferred_status_fail.status().error_message(),
+ testing::ContainsRegex("select function's second parameter"));
+}
+
+TEST_F(ShapeInferenceTest, Convolve) {
+ ConvolutionDimensionNumbers dnums;
+
+ // Dimension order: batch, feature, x0, x1
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
+ dnums.set_batch_dimension(0);
+ dnums.set_feature_dimension(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+
+ // Dimension order: x1, batch, feature, x0
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(1);
+ dnums.add_kernel_spatial_dimensions(3);
+ dnums.add_kernel_spatial_dimensions(0);
+
+ Window window;
+ auto dim0 = window.add_dimensions();
+ auto dim1 = window.add_dimensions();
+ dim0->set_size(3);
+ dim0->set_stride(2);
+ dim0->set_padding_low(1);
+ dim0->set_padding_high(1);
+ dim0->set_window_dilation(1);
+ dim0->set_base_dilation(1);
+ dim1->set_size(2);
+ dim1->set_stride(1);
+ dim1->set_padding_low(0);
+ dim1->set_padding_high(0);
+ dim1->set_window_dilation(1);
+ dim1->set_base_dilation(1);
+ auto inferred_status =
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred_shape = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
+ inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
+ ConvolutionDimensionNumbers dnums;
+
+ // Dimension order: batch, feature, x0, x1
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
+ dnums.set_batch_dimension(0);
+ dnums.set_feature_dimension(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+
+ // Dimension order: x1, batch, feature, x0
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(1);
+ dnums.add_kernel_spatial_dimensions(3);
+ dnums.add_kernel_spatial_dimensions(0);
+
+ Window window;
+ auto dim0 = window.add_dimensions();
+ dim0->set_size(3);
+ dim0->set_stride(3);
+ dim0->set_padding_low(0);
+ dim0->set_padding_high(0);
+ dim0->set_window_dilation(6);
+ dim0->set_base_dilation(1);
+
+ auto dim1 = window.add_dimensions();
+ dim1->set_size(2);
+ dim1->set_stride(1);
+ dim1->set_padding_low(2);
+ dim1->set_padding_high(1);
+ dim1->set_window_dilation(2);
+ dim1->set_base_dilation(1);
+ auto inferred_status =
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred_shape = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
+ inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
+ ConvolutionDimensionNumbers dnums;
+
+ // Dimension order: batch, feature, x0, x1
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
+ dnums.set_batch_dimension(0);
+ dnums.set_feature_dimension(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+
+ // Dimension order: x1, batch, feature, x0
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(1);
+ dnums.add_kernel_spatial_dimensions(3);
+ dnums.add_kernel_spatial_dimensions(0);
+
+ Window window;
+ auto dim0 = window.add_dimensions();
+ dim0->set_size(4);
+ dim0->set_stride(3);
+ dim0->set_padding_low(0);
+ dim0->set_padding_high(0);
+ dim0->set_window_dilation(1);
+ dim0->set_base_dilation(6);
+
+ auto dim1 = window.add_dimensions();
+ dim1->set_size(2);
+ dim1->set_stride(1);
+ dim1->set_padding_low(2);
+ dim1->set_padding_high(1);
+ dim1->set_window_dilation(1);
+ dim1->set_base_dilation(2);
+ auto inferred_status =
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred_shape = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
+ inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
+ // Dimension order for this test: batch, feature, x0, x1
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});
+
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(3);
+ dnums.set_feature_dimension(2);
+ dnums.add_spatial_dimensions(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0
+ dnums.set_kernel_output_feature_dimension(3);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+
+ Window window;
+ auto dim0 = window.add_dimensions();
+ auto dim1 = window.add_dimensions();
+ dim0->set_size(2);
+ dim0->set_stride(1);
+ dim0->set_padding_low(0);
+ dim0->set_padding_high(0);
+ dim1->set_size(3);
+ dim1->set_stride(2);
+ dim1->set_padding_low(1);
+ dim1->set_padding_high(1);
+ auto inferred_status =
+ ShapeInference::InferConvolveShape(lhs_shape, rhs_shape, window, dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("each dimension exactly once"));
+}
+
+TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
+ Shape arg = ShapeUtil::MakeShape(F32, {20});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ Shape expected = ShapeUtil::MakeShape(S32, {20});
+ EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, Map) {
+ auto inferred_status_r1f32 = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_32_},
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ EXPECT_IS_OK(inferred_status_r1f32.status());
+ EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));
+
+ // It's OK to provide a single argument, as long as the applied arity matches
+ // (this degenerates to a Map).
+ auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
+ {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ EXPECT_IS_OK(inferred_status_r1f32_one.status());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));
+
+ auto inferred_status_r2s32 = ShapeInference::InferMapShape(
+ {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
+ ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_));
+ EXPECT_IS_OK(inferred_status_r2s32.status());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));
+
+ auto no_args_error = ShapeInference::InferMapShape(
+ {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ASSERT_FALSE(no_args_error.ok());
+ ASSERT_MATCH(no_args_error.status().error_message(),
+ testing::ContainsRegex("expects at least one argument"));
+
+ auto args_diff_shapes_error = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_64_},
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ASSERT_FALSE(args_diff_shapes_error.ok());
+ ASSERT_MATCH(
+ args_diff_shapes_error.status().error_message(),
+ testing::ContainsRegex("requires all operands to have the same shape"));
+
+ auto arity_error = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ ASSERT_FALSE(arity_error.ok());
+ ASSERT_MATCH(arity_error.status().error_message(),
+ testing::ContainsRegex("function arity must match"));
+
+ auto output_shape_error = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_32_},
+ ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_));
+ ASSERT_FALSE(output_shape_error.ok());
+ ASSERT_MATCH(output_shape_error.status().error_message(),
+ testing::ContainsRegex("result has to be a scalar"));
+
+ auto param_shape_error = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_32_},
+ ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_));
+ ASSERT_FALSE(param_shape_error.ok());
+ ASSERT_MATCH(param_shape_error.status().error_message(),
+ testing::ContainsRegex("parameter has to be a scalar"));
+
+ auto param_element_type_error = ShapeInference::InferMapShape(
+ {&vector_32_, &vector_32_},
+ ShapeUtil::MakeProgramShape({f32_, s32_}, f32_));
+ ASSERT_FALSE(param_element_type_error.ok());
+ ASSERT_MATCH(param_element_type_error.status().error_message(),
+ testing::ContainsRegex("parameter type has to match argument"));
+
+ Shape arg = ShapeUtil::MakeShape(F32, {20});
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
+ auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply);
+ EXPECT_IS_OK(inferred_status.status());
+ EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));
+
+ auto inferred_status_error1 = ShapeInference::InferMapShape(
+ {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_));
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("arity must match number of arguments"));
+
+ auto inferred_status_error2 = ShapeInference::InferMapShape(
+ {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_));
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("has to be a scalar"));
+
+ auto inferred_status_error3 = ShapeInference::InferMapShape(
+ {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_));
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(inferred_status_error3.status().error_message(),
+ testing::ContainsRegex("has to be a scalar"));
+
+ auto inferred_status_error5 = ShapeInference::InferMapShape(
+ {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_));
+ ASSERT_FALSE(inferred_status_error5.ok());
+ ASSERT_MATCH(inferred_status_error5.status().error_message(),
+ testing::ContainsRegex("parameter type has to match argument"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
+ ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
+ /*dimensions_to_reduce=*/{0});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{0});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{1});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{0, 1});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{1, 2});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{0, 2});
+
+ // Check that the order of dimensions_to_reduce doesn't matter.
+ ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
+ ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{2, 0});
+}
+
+TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
+ ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
+ /*dimensions_to_reduce=*/{0, 1, 2});
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
+ auto inferred_status = ShapeInference::InferReduceShape(
+ ShapeUtil::MakeShape(F32, {5, 3}), f32_, /*dimensions_to_reduce=*/{3, 4},
+ to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("out-of-bounds dimension"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
+ auto inferred_status =
+ ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ /*dimensions_to_reduce=*/{0}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("take 2 parameters"));
+}
+
+TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
+ ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
+ auto inferred_status =
+ ShapeInference::InferReduceShape(ShapeUtil::MakeShape(F32, {5, 3}), f32_,
+ /*dimensions_to_reduce=*/{0}, to_apply);
+ EXPECT_FALSE(inferred_status.ok());
+ EXPECT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("first parameter shape differs"));
+}
+
+TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
+}
+
+TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
+ Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2});
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
+ inferred_status.status().code());
+}
+
+TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
+ Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
+ auto inferred_status =
+ ShapeInference::InferSliceShape(vector_shape, {2}, {4});
+ ASSERT_TRUE(inferred_status.ok());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
+}
+
+TEST_F(ShapeInferenceTest, InferConstIndexShape) {
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
+ auto inferred0_status =
+ ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
+ auto inferred1_status =
+ ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
+ ASSERT_IS_OK(inferred0_status.status());
+ ASSERT_IS_OK(inferred1_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
+ ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferPowShape) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_POW, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeEq) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_EQ, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeGe) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_GE, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeGt) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_GT, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeLe) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_LE, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeLt) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_LT, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, InferCompareShapeNe) {
+ auto ten_floats = ShapeUtil::MakeShape(F32, {10});
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_NE, ten_floats, f32_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
+ inferred_status.ValueOrDie()));
+}
+
+TEST_F(ShapeInferenceTest, BroadcastScalar) {
+ for (auto element_type : {F32, U32, S8}) {
+ const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
+ { // no-op scalar broadcast
+ auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
+ ASSERT_IS_OK(status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
+ }
+ const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
+ { // scalar -> 1d broadcast
+ auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
+ ASSERT_IS_OK(status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
+ }
+ { // no-op 1d broadcast
+ auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
+ ASSERT_IS_OK(status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
+ }
+ const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});
+ { // scalar -> 2d broadcast
+ auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
+ ASSERT_IS_OK(status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
+ }
+ { // 1d -> 2d broadcast
+ auto status = ShapeInference::InferBroadcastShape(oned_shape, {2});
+ ASSERT_IS_OK(status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
+ }
+ }
+}
+
+// scalar <dot> vector: error
+TEST_F(ShapeInferenceTest, ScalarDotVector) {
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {});
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("dot only supports rank"));
+}
+
+// 3D <dot> 2D: error
+TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {});
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_MATCH(inferred_status.status().error_message(),
+ testing::ContainsRegex("dot only supports rank"));
+}
+
+// vector <dot> vector -> scalar
+TEST_F(ShapeInferenceTest, VectorDotVector) {
+ auto inferred_status =
+ ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
+ auto inferred_status_mismatch =
+ ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+}
+
+// matrix <dot> vector -> vector
+TEST_F(ShapeInferenceTest, MatrixDotVector) {
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
+ auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+}
+
+// vector <dot> matrix -> vector
+TEST_F(ShapeInferenceTest, VectorDotMatrix) {
+ auto inferred_status = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {});
+ ASSERT_IS_OK(inferred_status.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
+ auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+}
+
+// matrix <dot> matrix -> matrix
+TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
+ auto inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(
+ ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
+ << "inferred: "
+ << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
+ << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
+ auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+}
+
+TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
+ // Test variations of broadcasting a vector for a binary add with a
+ // matrix.
+ const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
+ const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
+ const Shape vec16 = ShapeUtil::MakeShape(F32, {16});
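+ // broadcast_dimensions maps each dimension of the lower-rank operand to a
+ // dimension of the higher-rank operand: {1} aligns vec8 with mat's dimension
+ // 1 (size 8), while {0} would align it with dimension 0 (size 16) and fail.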
+
+ auto inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, mat, vec8, {1});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
+
+ auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, mat, vec8, {0});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+
+ inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, mat, vec16, {0});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));
+
+ inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, mat, vec16, {1});
+ ASSERT_FALSE(inferred_status_mismatch.ok());
+}
+
+TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
+ // Test variations of broadcasting a matrix for a binary add with a cube.
+ const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
+ const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
+ const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
+ const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});
+
+ auto inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, cube, matrix8_4, {1, 2});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
+
+ inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, cube, matrix16_4, {0, 2});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
+
+ inferred_status_match = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, cube, matrix16_8, {0, 1});
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
+}
+
+TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
+ // Test various errors with the broadcast argument.
+ const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
+ const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
+ const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
+ const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
+ const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
+
+ // "magical" broadcast rejected
+ auto inferred_status_error1 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, vec8, {});
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("automatic"));
+
+ // broadcast_dimension out of bounds for tensor's rank
+ auto inferred_status_error2 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, vec8, {3});
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(
+ inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("broadcast dimension number .* too large"));
+
+ // broadcast_dimension doesn't match corresponding dimension
+ auto inferred_status_error3 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, vec8, {0});
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(inferred_status_error3.status().error_message(),
+ testing::ContainsRegex("broadcast dimension 0 mismatch"));
+
+ // broadcast_dimensions list too long
+ auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, matrix8_4, {0, 1, 2});
+ ASSERT_FALSE(inferred_status_error4.ok());
+ ASSERT_MATCH(
+ inferred_status_error4.status().error_message(),
+ testing::ContainsRegex("size of broadcast_dimensions has to match"));
+
+ // there's a dimension above the rank of the tensor
+ auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, matrix8_4, {3, 0});
+ ASSERT_FALSE(inferred_status_error5.ok());
+ ASSERT_MATCH(
+ inferred_status_error5.status().error_message(),
+ testing::ContainsRegex("broadcast dimension number .* too large"));
+
+ // broadcasting dimensions don't match in this order
+ auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor, matrix8_4, {2, 1});
+ ASSERT_FALSE(inferred_status_error6.ok());
+ ASSERT_MATCH(inferred_status_error6.status().error_message(),
+ testing::ContainsRegex("broadcast dimension 0 mismatch"));
+
+ // The following two tests make sure that broadcasting dimensions are listed
+ // in a proper (strictly increasing) order, even if the lower-rank array
+ // matches the higher-rank array in many different ways.
+ auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {0, 0});
+ ASSERT_FALSE(inferred_status_error7.ok());
+ ASSERT_MATCH(inferred_status_error7.status().error_message(),
+ testing::ContainsRegex("broadcast dimensions order is wrong"));
+
+ auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
+ BinaryOperation::BINOP_ADD, tensor8_8_8, matrix8_8, {1, 0});
+ ASSERT_FALSE(inferred_status_error8.ok());
+ ASSERT_MATCH(inferred_status_error8.status().error_message(),
+ testing::ContainsRegex("broadcast dimensions order is wrong"));
+}
+
+// Tests for the while instruction with proper shapes.
+TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
+ Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
+ ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
+ ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
+ auto inferred_status =
+ ShapeInference::InferWhileShape(cond, body, result_shape);
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
+}
+
+// Tests for the while instruction with wrong shapes.
+TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
+ Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
+ ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
+ ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
+
+ auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
+ auto inferred_status_error1 =
+ ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("condition must take 1 arguments"));
+
+ auto bad_shape_2 =
+ ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
+ auto inferred_status_error2 =
+ ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("body must take 1 arguments"));
+
+ auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
+ auto inferred_status_error3 =
+ ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(inferred_status_error3.status().error_message(),
+ testing::ContainsRegex("condition must return a boolean"));
+
+ auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
+ auto inferred_status_error4 =
+ ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
+ ASSERT_FALSE(inferred_status_error4.ok());
+ ASSERT_MATCH(inferred_status_error4.status().error_message(),
+ testing::ContainsRegex("parameter of condition and body"));
+}
+
+// Tests for the concatenate instruction with proper shapes.
+TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
+ auto inferred_status_1 = ShapeInference::InferConcatOpShape(
+ {&vector_32_, &vector_64_}, /*dimension=*/0);
+ ASSERT_IS_OK(inferred_status_1.status());
+ Shape inferred_1 = inferred_status_1.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
+
+ auto inferred_status_2 = ShapeInference::InferConcatOpShape(
+ {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
+ ASSERT_IS_OK(inferred_status_2.status());
+ Shape inferred_2 = inferred_status_2.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
+
+ auto inferred_status_3 = ShapeInference::InferConcatOpShape(
+ {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
+ ASSERT_IS_OK(inferred_status_3.status());
+ Shape inferred_3 = inferred_status_3.ValueOrDie();
+ ASSERT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
+}
+
+// Tests for the concatenate instruction with wrong shapes.
+TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
+ auto inferred_status_error1 =
+ ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(
+ inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("Concatenate expects at least one argument"));
+
+ auto inferred_status_error2 =
+ ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex(
+ "dimension to concatenate along out of bounds: -1"));
+
+ auto inferred_status_error3 =
+ ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(inferred_status_error3.status().error_message(),
+ testing::ContainsRegex(
+ "dimension to concatenate along out of bounds: 1"));
+
+ Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
+ auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
+ {&vector_32_, &tuple}, /*dimension=*/0);
+ ASSERT_FALSE(inferred_status_error4.ok());
+ ASSERT_MATCH(
+ inferred_status_error4.status().error_message(),
+ testing::ContainsRegex(
+ "Expected non-tuple argument for operand of concatenation."));
+
+ const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
+ auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
+ {&vector_32_, &vector_s32}, /*dimension=*/0);
+ ASSERT_FALSE(inferred_status_error5.ok());
+ ASSERT_MATCH(inferred_status_error5.status().error_message(),
+ testing::ContainsRegex(
+ "cannot concatenate arrays with different element types"));
+
+ auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
+ {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
+ ASSERT_FALSE(inferred_status_error6.ok());
+ ASSERT_MATCH(
+ inferred_status_error6.status().error_message(),
+ testing::ContainsRegex("cannot concatenate arrays that differ in "
+ "dimensions other than the one being "
+ "concatenated"));
+}
+
+TEST_F(ShapeInferenceTest, Pad) {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
+ Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
+ // Padding for dimension 0: {low: 0, high: 2, interior: 3}
+ // Padding for dimension 1: {low: 1, high: 5, interior: 0}
+ PaddingConfig padding_config;
+ auto dimension0 = padding_config.add_dimensions();
+ dimension0->set_edge_padding_low(0);
+ dimension0->set_edge_padding_high(2);
+ dimension0->set_interior_padding(3);
+ auto dimension1 = padding_config.add_dimensions();
+ dimension1->set_edge_padding_low(1);
+ dimension1->set_edge_padding_high(5);
+ dimension1->set_interior_padding(0);
+
+ auto inferred_status = ShapeInference::InferPadShape(
+ input_shape, padding_value_shape, padding_config);
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred_shape = inferred_status.ValueOrDie();
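+ // Expected dims: 10 + 0 + 2 + 3 * (10 - 1) = 39 and 25 + 1 + 5 + 0 = 31.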
+ ASSERT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, Reverse) {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
+
+ auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1});
+ ASSERT_IS_OK(inferred_status.status());
+ Shape inferred_shape = inferred_status.ValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape));
+}
+
+TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
+
+ auto inferred_status_error0 =
+ ShapeInference::InferReverseShape(input_shape, {0, 2});
+ ASSERT_FALSE(inferred_status_error0.ok());
+ ASSERT_MATCH(inferred_status_error0.status().error_message(),
+ testing::ContainsRegex("out-of-bounds"));
+
+ auto inferred_status_error1 =
+ ShapeInference::InferReverseShape(input_shape, {0, -1});
+ ASSERT_FALSE(inferred_status_error1.ok());
+ ASSERT_MATCH(inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("out-of-bounds"));
+
+ auto inferred_status_error2 =
+ ShapeInference::InferReverseShape(input_shape, {0, 0});
+ ASSERT_FALSE(inferred_status_error2.ok());
+ ASSERT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("duplicated"));
+
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
+ auto inferred_status_error3 =
+ ShapeInference::InferReverseShape(tuple_shape, {0});
+ ASSERT_FALSE(inferred_status_error3.ok());
+ ASSERT_MATCH(inferred_status_error3.status().error_message(),
+ testing::ContainsRegex("Expected non-tuple argument"));
+}
+
+TEST_F(ShapeInferenceTest, Call) {
+ auto inferred_status0 =
+ ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
+ EXPECT_IS_OK(inferred_status0.status());
+ EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
+
+ auto inferred_status1 = ShapeInference::InferCallShape(
+ {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
+ ShapeUtil::MakeProgramShape(
+ {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
+ EXPECT_IS_OK(inferred_status1.status());
+ EXPECT_TRUE(
+ ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));
+
+ auto inferred_status_error0 = ShapeInference::InferCallShape(
+ {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
+ EXPECT_FALSE(inferred_status_error0.ok());
+ EXPECT_MATCH(inferred_status_error0.status().error_message(),
+ testing::ContainsRegex("arity must match"));
+
+ auto inferred_status_error1 = ShapeInference::InferCallShape(
+ {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
+ EXPECT_FALSE(inferred_status_error1.ok());
+ EXPECT_MATCH(inferred_status_error1.status().error_message(),
+ testing::ContainsRegex("arity must match"));
+
+ auto inferred_status_error2 = ShapeInference::InferCallShape(
+ {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
+ EXPECT_FALSE(inferred_status_error2.ok());
+ EXPECT_MATCH(inferred_status_error2.status().error_message(),
+ testing::ContainsRegex("parameter must match argument"));
+}
+
+TEST_F(ShapeInferenceTest, Transpose) {
+ Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
+ auto inferred_shape_and_status =
+ ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
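+ // Output dimension i is input dimension permutation[i]: {2,3,4,5} -> {3,4,5,2}.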
+ EXPECT_IS_OK(inferred_shape_and_status);
+ Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
+ EXPECT_TRUE(ShapeUtil::Equal(inferred_shape,
+ ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
new file mode 100644
index 0000000000..cf49fd72b7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeShapedBuffer(const Shape& shape,
+ const perftools::gputools::Platform* platform,
+ int device_ordinal) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Shape must have a layout: %s",
+ ShapeUtil::HumanStringWithLayout(shape).c_str());
+ }
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+ return WrapUnique(new ShapedBuffer(shape, platform, device_ordinal));
+}
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeArrayShapedBuffer(
+ const Shape& shape, const perftools::gputools::Platform* platform,
+ int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer) {
+ if (ShapeUtil::IsTuple(shape)) {
+ return InvalidArgument("Shape must be an array: %s",
+ ShapeUtil::HumanStringWithLayout(shape).c_str());
+ }
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
+ MakeShapedBuffer(shape, platform, device_ordinal));
+ *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) =
+ 0;
+ *shaped_buffer->mutable_buffers() = {buffer};
+ return std::move(shaped_buffer);
+}
+
+/* static */ StatusOr<std::unique_ptr<ShapedBuffer>>
+ShapedBuffer::MakeUnnestedTupleShapedBuffer(
+ const Shape& shape, const perftools::gputools::Platform* platform,
+ int device_ordinal,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ buffers) {
+ if (!ShapeUtil::IsTuple(shape) || ShapeUtil::IsNestedTuple(shape)) {
+ return InvalidArgument("Shape must be an unnested tuple: %s",
+ ShapeUtil::HumanStringWithLayout(shape).c_str());
+ }
+ if (buffers.size() != ShapeUtil::TupleElementCount(shape)) {
+ return InvalidArgument("Tuple has %lld elements, but %zu buffers given",
+ ShapeUtil::TupleElementCount(shape), buffers.size());
+ }
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> shaped_buffer,
+ MakeShapedBuffer(shape, platform, device_ordinal));
+ TF_CHECK_OK(shaped_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [](const ShapeIndex& index, bool is_leaf,
+ size_t* buffer_element) -> tensorflow::Status {
+ if (is_leaf) {
+ CHECK_EQ(index.size(), 1);
+ *buffer_element = index[0];
+ }
+ return tensorflow::Status::OK();
+ }));
+ shaped_buffer->mutable_buffers()->reserve(buffers.size());
+ for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) {
+ shaped_buffer->mutable_buffers()->push_back(memory_base);
+ }
+ return std::move(shaped_buffer);
+}
+
+ShapedBuffer::ShapedBuffer(const Shape& shape,
+ const perftools::gputools::Platform* platform,
+ int device_ordinal)
+ : shape_(shape),
+ shape_index_to_buffer_entry_(shape),
+ platform_(platform),
+ device_ordinal_(device_ordinal) {}
+
+const perftools::gputools::DeviceMemoryBase& ShapedBuffer::buffer(
+ const ShapeIndex& index) const {
+ // Buffers are only set at the leaves (array elements of the shape).
+ CHECK(shape_index_to_buffer_entry_.IsLeaf(index));
+ return buffers_[shape_index_to_buffer_entry_.element(index)];
+}
+
+perftools::gputools::DeviceMemoryBase* ShapedBuffer::mutable_buffer(
+ const ShapeIndex& index) {
+ // Buffers are only set at the leaves (array elements of the shape).
+ CHECK(shape_index_to_buffer_entry_.IsLeaf(index));
+ return &buffers_[shape_index_to_buffer_entry_.element(index)];
+}
+
+/* static */ StatusOr<std::unique_ptr<ScopedShapedBuffer>>
+ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape,
+ DeviceMemoryAllocator* allocator,
+ int device_ordinal) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Shape must have a layout: %s",
+ ShapeUtil::HumanStringWithLayout(shape).c_str());
+ }
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+ auto shaped_buffer =
+ WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal));
+
+ // Allocate an appropriate sized buffer for each array element in the shape.
+ TF_RETURN_IF_ERROR(
+ shaped_buffer->shape_index_to_buffer_entry_.ForEachMutableElement(
+ [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
+ size_t* buffer_entry) -> tensorflow::Status {
+ if (is_leaf) {
+ TF_ASSIGN_OR_RETURN(
+ perftools::gputools::DeviceMemoryBase memory_base,
+ shaped_buffer->allocator_->Allocate(
+ shaped_buffer->device_ordinal(),
+ ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape(
+ shaped_buffer->shape(), index))));
+ shaped_buffer->buffers_.push_back(memory_base);
+ *buffer_entry = shaped_buffer->buffers_.size() - 1;
+ }
+ return tensorflow::Status::OK();
+ }));
+ return std::move(shaped_buffer);
+}
+
+ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape,
+ DeviceMemoryAllocator* allocator,
+ int device_ordinal)
+ : ShapedBuffer(shape, allocator->platform(), device_ordinal),
+ allocator_(allocator) {}
+
+ScopedShapedBuffer::~ScopedShapedBuffer() {
+ // Deallocate all non-null buffers. A buffer may appear in more than one spot
+ // in the shape (e.g., a tuple with a repeated element), so keep track of what
+ // has been deallocated.
+ std::set<void*> deallocated_opaques;
+ for (perftools::gputools::DeviceMemoryBase& memory_base : buffers_) {
+ if (!memory_base.is_null() &&
+ deallocated_opaques.count(memory_base.opaque()) == 0) {
+ deallocated_opaques.insert(memory_base.opaque());
+ TF_CHECK_OK(
+ this->allocator_->Deallocate(this->device_ordinal(), &memory_base));
+ }
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
new file mode 100644
index 0000000000..aa3b932c4e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Class which encapsulates a buffer or set of buffers containing data of a
+// particular XLA shape. Used as the zero-copy execution interface for an
+// XLA client running in the same process as the service (LocalClient).
+class ShapedBuffer {
+ public:
+ // Creates a ShapedBuffer of arbitrary shape. All buffer pointers
+ // (DeviceMemoryBase) in the returned ShapedBuffer are initialized to null.
+ static StatusOr<std::unique_ptr<ShapedBuffer>> MakeShapedBuffer(
+ const Shape& shape, const perftools::gputools::Platform* platform,
+ int device_ordinal);
+
+ // Convenience method which creates a ShapedBuffer of array shape (not a
+ // tuple). Its single buffer pointer is set to the given value "buffer". The
+ // given buffer must be large enough to store the given shape as given by
+ // ShapeUtil::ByteSizeOf.
+ static StatusOr<std::unique_ptr<ShapedBuffer>> MakeArrayShapedBuffer(
+ const Shape& shape, const perftools::gputools::Platform* platform,
+ int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer);
+
+ // Convenience method which creates a ShapedBuffer of a non-nested tuple. The
+ // buffer pointers in the returned ShapedBuffer are set to the given
+ // "buffers". The number of buffers must match the number of elements in the
+ // tuple shape, and each buffer must be large enough to store its respective
+ // element shape as given by ShapeUtil::ByteSizeOf.
+ static StatusOr<std::unique_ptr<ShapedBuffer>> MakeUnnestedTupleShapedBuffer(
+ const Shape& shape, const perftools::gputools::Platform* platform,
+ int device_ordinal,
+ const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ buffers);
+
+ const Shape& shape() const { return shape_; }
+ const perftools::gputools::Platform* platform() const { return platform_; }
+ int device_ordinal() const { return device_ordinal_; }
+
+ // Returns the buffer at the given shape index where index is defined as in
+ // ShapeUtil::GetSubshape.
+ const perftools::gputools::DeviceMemoryBase& buffer(
+ const ShapeIndex& index) const;
+ perftools::gputools::DeviceMemoryBase* mutable_buffer(
+ const ShapeIndex& index);
+
+ // Returns the underlying structure which stores the buffer pointers.
+ const std::vector<perftools::gputools::DeviceMemoryBase>& buffers() const {
+ return buffers_;
+ }
+ std::vector<perftools::gputools::DeviceMemoryBase>* mutable_buffers() {
+ return &buffers_;
+ }
+
+ // Returns the tree of indices which map to buffer pointers.
+ const ShapeTree<size_t>& shape_index_to_buffer_entry() const {
+ return shape_index_to_buffer_entry_;
+ }
+ ShapeTree<size_t>* mutable_shape_index_to_buffer_entry() {
+ return &shape_index_to_buffer_entry_;
+ }
+
+ protected:
+ ShapedBuffer(const Shape& shape,
+ const perftools::gputools::Platform* platform,
+ int device_ordinal);
+
+ // The shape of the device buffer with layout.
+ const Shape shape_;
+
+ // The list of DeviceMemoryBase pointers representing this shape.
+ // Note that there can be a many-to-one relationship between tuple elements
+ // and buffers. To account for this, shape_index_to_buffer_entry_ maps from a
+ // position in the shape to an index into this list.
+ std::vector<perftools::gputools::DeviceMemoryBase> buffers_;
+
+ // The tree of indices into buffers_.
+ ShapeTree<size_t> shape_index_to_buffer_entry_;
+
+ // The platform the memory is allocated on.
+ const perftools::gputools::Platform* platform_;
+
+ // The device the memory is allocated on.
+ const int device_ordinal_;
+};
+
+// ShapedBuffer derived class which allocates all internal buffers on
+// construction and deallocates the memory when the object is
+// destructed.
+class ScopedShapedBuffer : public ShapedBuffer {
+ public:
+ // Returns a new ScopedShapedBuffer of an arbitrary shape. All buffers in the
+ // ScopedShapedBuffer are automatically allocated to exactly the size of
+ // their respective array shape.
+ static StatusOr<std::unique_ptr<ScopedShapedBuffer>> MakeScopedShapedBuffer(
+ const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal);
+
+ // All buffers in the shape are deallocated on destruction.
+ ~ScopedShapedBuffer();
+
+ protected:
+ ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator,
+ int device_ordinal);
+ ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
+ void operator=(const ScopedShapedBuffer&) = delete;
+
+ DeviceMemoryAllocator* allocator_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
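A minimal usage sketch of the ScopedShapedBuffer API declared above (a hedged illustration, not part of the patch): it assumes a valid `allocator` of type DeviceMemoryAllocator* for the target device, a Status-returning caller so the TF_* macros apply, and that ShapeUtil::MakeShape yields a shape with a default layout.

  // Allocate one device buffer per array leaf; all buffers are freed when
  // `scoped` goes out of scope. `allocator` is a placeholder, not defined here.
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::ScopedShapedBuffer> scoped,
      xla::ScopedShapedBuffer::MakeScopedShapedBuffer(shape, allocator,
                                                      /*device_ordinal=*/0));
  // For an array (non-tuple) shape the only leaf is the empty ShapeIndex.
  const perftools::gputools::DeviceMemoryBase& mem = scoped->buffer({});
  (void)mem;  // e.g. hand to a TransferManager to populate with data.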
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
new file mode 100644
index 0000000000..c7f6a13023
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -0,0 +1,143 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+
+#include <string>
+#include <utility>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/* static */ tensorflow::mutex*
+TransferManager::platform_transfer_manager_mutex() {
+ static tensorflow::mutex* m = new tensorflow::mutex;
+ return m;
+}
+
+/* static */ std::map<perftools::gputools::Platform::Id,
+ TransferManager::State>*
+TransferManager::GetPlatformTransferManagers() {
+ static auto* r =
+ new std::map<perftools::gputools::Platform::Id, TransferManager::State>;
+ return r;
+}
+
+/* static */ void TransferManager::RegisterTransferManager(
+ se::Platform::Id platform_id,
+ TransferManagerCreationFunction creation_function) {
+ tensorflow::mutex_lock lock(
+ *TransferManager::platform_transfer_manager_mutex());
+ auto* managers = GetPlatformTransferManagers();
+ CHECK(managers->find(platform_id) == managers->end());
+ (*managers)[platform_id].creation_function = creation_function;
+}
+
+/* static */ StatusOr<TransferManager*> TransferManager::GetForPlatform(
+ const se::Platform* platform) {
+ tensorflow::mutex_lock lock(
+ *TransferManager::platform_transfer_manager_mutex());
+ auto* managers = GetPlatformTransferManagers();
+
+ auto it = managers->find(platform->id());
+ if (it == managers->end()) {
+ return NotFound(
+ "could not find registered transfer manager for platform %s -- check "
+ "target linkage",
+ platform->Name().c_str());
+ }
+
+ if (it->second.manager == nullptr) {
+ // Lazily create the transfer manager the first time it is needed
+ it->second.manager = (*it->second.creation_function)();
+ }
+
+ return it->second.manager;
+}
+
+Status TransferManager::TransferBufferFromDevice(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ int64 size, void* destination) {
+ if (source.size() < size) {
+ return FailedPrecondition(
+ "Source allocation on device not large enough for data tranfer: "
+ "%lld < %lld",
+ source.size(), size);
+ }
+ auto copy_status = executor->SynchronousMemcpyD2H(source, size, destination);
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer from device to buffer");
+ }
+ return Status::OK();
+}
+
+Status TransferManager::TransferBufferToDevice(
+ se::StreamExecutor* executor, int64 size, const void* source,
+ se::DeviceMemoryBase* destination) {
+ if (destination->size() < size) {
+ return FailedPrecondition(
+ "Destination allocation on device not large enough for data tranfer: "
+ "%lld < %lld",
+ destination->size(), size);
+ }
+ auto copy_status = executor->SynchronousMemcpyH2D(source, size, destination);
+ if (!copy_status.ok()) {
+ return AddStatus(
+ Status(static_cast<tensorflow::error::Code>(copy_status.code()),
+ copy_status.error_message()),
+ "failed transfer of buffer to device");
+ }
+ return Status::OK();
+}
+
+StatusOr<std::set<se::DeviceMemoryBase>>
+TransferManager::GatherBufferPointersFromTuple(
+ se::StreamExecutor* executor, const se::DeviceMemoryBase& source,
+ const Shape& shape) {
+ TF_RET_CHECK(ShapeUtil::IsTuple(shape));
+
+ std::set<se::DeviceMemoryBase> buffer_pointers;
+ buffer_pointers.insert(source);
+
+ TF_ASSIGN_OR_RETURN(std::vector<se::DeviceMemoryBase> tuple_elements,
+ ShallowCopyTupleFromDevice(executor, source, shape));
+ for (size_t i = 0; i < tuple_elements.size(); ++i) {
+ const Shape& element_shape = shape.tuple_shapes(i);
+ if (ShapeUtil::IsTuple(element_shape)) {
+ TF_ASSIGN_OR_RETURN(
+ std::set<se::DeviceMemoryBase> buffer_pointers_in_element,
+ GatherBufferPointersFromTuple(executor, tuple_elements[i],
+ element_shape));
+ buffer_pointers.insert(buffer_pointers_in_element.begin(),
+ buffer_pointers_in_element.end());
+ } else {
+ buffer_pointers.insert(tuple_elements[i]);
+ }
+ }
+ return std::move(buffer_pointers);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
new file mode 100644
index 0000000000..90dc921b7d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -0,0 +1,151 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
+
+#include <map>
+#include <set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// The TransferManager interface lets backends provide platform-specific
+// mechanisms for constructing literals from given device memory handles.
+// This lets each platform customize how literals are transferred to/from the
+// device in terms of padding, leading dimension, etc.
+class TransferManager {
+ public:
+ virtual ~TransferManager() {}
+
+ // Returns the ID of the platform that this transfer manager acts on.
+ virtual perftools::gputools::Platform::Id PlatformId() const = 0;
+
+ // Transfers the region into the provided literal using the provided
+ // executor. device_shape is the shape, including layout, of the data on the
+ // device, while literal_shape will be the shape for the literal. device_shape
+ // and literal_shape must be compatible, but need not have the same layout.
+ virtual Status TransferLiteralFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& region,
+ const Shape& device_shape, const Shape& literal_shape,
+ Literal* literal) = 0;
+
+ // Transfers the given literal into the provided region output parameter,
+ // using the given executor.
+ virtual Status TransferLiteralToDevice(
+ perftools::gputools::StreamExecutor* executor, const Literal& literal,
+ perftools::gputools::DeviceMemoryBase* region) = 0;
+
+ // Transfers the given literal into the Infeed interface of the device,
+ // using the given executor.
+ virtual Status TransferLiteralToInfeed(
+ perftools::gputools::StreamExecutor* executor,
+ const Literal& literal) = 0;
+
+ // Resets the device that the given executor runs on.
+ virtual Status ResetDevice(perftools::gputools::StreamExecutor* executor) = 0;
+
+ // Shallow copies a tuple from the device and creates a DeviceMemoryBase object
+ // for each element in the tuple. A DeviceMemoryBase object refers to the
+ // buffer containing the data of that element. The DeviceMemoryBase objects
+ // are returned as a vector.
+ virtual StatusOr<std::vector<perftools::gputools::DeviceMemoryBase>>
+ ShallowCopyTupleFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source,
+ const Shape& shape) = 0;
+
+ // Returns all buffer pointers that the tuple `source` refers to. Unlike
+ // ShallowCopyTupleFromDevice, this function gathers buffer pointers in nested
+ // tuples as well. Also, the returned DeviceMemoryBase objects are
+ // deduplicated.
+ StatusOr<std::set<perftools::gputools::DeviceMemoryBase>>
+ GatherBufferPointersFromTuple(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source, const Shape& shape);
+
+ // Determines the byte size requirement for the given shape on the underlying
+ // architecture. This will be used to allocate an appropriately sized memory
+ // region for a host-to-device transfer.
+ virtual int64 GetByteSizeRequirement(const Shape& shape) = 0;
+
+ // Transfer a memory block of the given size from the device source into the
+ // 'destination' buffer.
+ //
+ // size is the size to transfer to destination in bytes.
+ virtual Status TransferBufferFromDevice(
+ perftools::gputools::StreamExecutor* executor,
+ const perftools::gputools::DeviceMemoryBase& source, int64 size,
+ void* destination);
+
+ // Transfer a memory block of the given size from 'source' buffer to the given
+ // destination of the device.
+ //
+ // size is the size to transfer from source in bytes.
+ virtual Status TransferBufferToDevice(
+ perftools::gputools::StreamExecutor* executor, int64 size,
+ const void* source, perftools::gputools::DeviceMemoryBase* destination);
+
+ typedef TransferManager* (*TransferManagerCreationFunction)();
+
+ /////
+ // The TransferManager class also serves as a point to register objects for
+ // the various platforms.
+
+ // Registers the TransferManager singleton for the platform kind. This is
+ // assumed to be a singleton, so no ownership is transferred.
+ //
+ // Precondition: a platform kind must not be registered more than once.
+ static void RegisterTransferManager(
+ perftools::gputools::Platform::Id platform_id,
+ TransferManagerCreationFunction transfer_manager);
+
+ // Returns the transfer manager singleton pointer if it is available for the
+ // given platform, or an error status if it is not.
+ static StatusOr<TransferManager*> GetForPlatform(
+ const perftools::gputools::Platform* platform);
+
+ private:
+ // Returns the mutex that guards the platform-to-transfer-manager map. Done
+ // as a routine to ensure correct initialization ordering, since
+ // RegisterTransferManager can be called during program initialization.
+ static tensorflow::mutex* platform_transfer_manager_mutex();
+
+ // State kept for each kind of TransferManager. Registration functions
+ // set up creation_function, and then we use that to lazily create
+ // "manager" the first time GetForPlatform is invoked for a particular id.
+ struct State {
+ TransferManager* manager = nullptr;
+ TransferManagerCreationFunction creation_function = nullptr;
+ };
+
+ // Map from platform kind to transfer manager singleton.
+ static std::map<perftools::gputools::Platform::Id, State>*
+ GetPlatformTransferManagers();
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSFER_MANAGER_H_
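A hedged sketch of the lookup-and-transfer flow described above (not part of the patch): `platform`, `executor`, and `device_mem` are assumed to be a valid se::Platform*, se::StreamExecutor*, and a preallocated se::DeviceMemoryBase large enough for the data, and the code is assumed to run in a Status-returning function so the TF_* macros apply.

  // Look up the TransferManager registered for this platform; registration is
  // expected to have happened via RegisterTransferManager at startup.
  TF_ASSIGN_OR_RETURN(xla::TransferManager* const manager,
                      xla::TransferManager::GetForPlatform(platform));
  // Copy a small host buffer into the preallocated device region.
  std::vector<float> host_data = {1.0f, 2.0f, 3.0f, 4.0f};
  TF_RETURN_IF_ERROR(manager->TransferBufferToDevice(
      executor, /*size=*/host_data.size() * sizeof(float), host_data.data(),
      &device_mem));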
diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc
new file mode 100644
index 0000000000..564111c4f2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc
@@ -0,0 +1,159 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/generic_transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+namespace {
+
+class CpuTransferManagerTest : public ::testing::Test {
+ protected:
+ CpuTransferManagerTest() : transfer_manager_(se::host::kHostPlatformId) {
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId)
+ .ValueOrDie();
+ stream_exec_ =
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie();
+ }
+
+ ~CpuTransferManagerTest() override {}
+
+ se::StreamExecutor* stream_exec_;
+ GenericTransferManager transfer_manager_;
+};
+
+TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) {
+ std::vector<uint8> storage(sizeof(uint32), '\x00');
+ se::DeviceMemoryBase memptr(storage.data(), storage.size());
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<uint32>(42);
+ TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
+ &memptr));
+
+ CHECK_EQ(42, *reinterpret_cast<uint32*>(&storage[0]));
+}
+
+TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) {
+ std::vector<uint8> storage(4 * sizeof(float), '\x00');
+ se::DeviceMemoryBase memptr(storage.data(), storage.size());
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::CreateR1<float>({1.25f, 2.5f, -17.0f, -20.125f});
+ TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
+ &memptr));
+
+ CHECK_EQ(1.25f, *reinterpret_cast<float*>(&storage[0]));
+ CHECK_EQ(2.5f, *reinterpret_cast<float*>(&storage[sizeof(float)]));
+ CHECK_EQ(-17.0f, *reinterpret_cast<float*>(&storage[2 * sizeof(float)]));
+ CHECK_EQ(-20.125f, *reinterpret_cast<float*>(&storage[3 * sizeof(float)]));
+}
+
+TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) {
+ std::vector<uint8> storage(16, '\x00');
+ se::DeviceMemoryBase memptr(storage.data(), storage.size());
+ const char* str = "0123456789abcdef";
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1U8(str);
+ TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal,
+ &memptr));
+
+ CHECK_EQ('0', storage[0]);
+ CHECK_EQ('8', storage[8]);
+ CHECK_EQ('f', storage[15]);
+}
+
+TEST_F(CpuTransferManagerTest, TransferR0U32FromDevice) {
+ std::vector<uint32> storage(1, 42);
+ se::DeviceMemoryBase memptr(storage.data(),
+ storage.size() * sizeof(storage[0]));
+ Literal literal;
+ const Shape shape = ShapeUtil::MakeShape(U32, {});
+ TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
+ stream_exec_, memptr, shape, shape, &literal));
+
+ LiteralTestUtil::ExpectR0Equal<uint32>(42, literal);
+}
+
+TEST_F(CpuTransferManagerTest, TransferR1F32FromDevice) {
+ std::vector<float> storage{1.25f, 2.5f, -17.0f, -20.125f};
+ se::DeviceMemoryBase memptr(storage.data(),
+ storage.size() * sizeof(storage[0]));
+ Literal literal;
+ const Shape shape = ShapeUtil::MakeShape(F32, {4});
+ TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
+ stream_exec_, memptr, shape, shape, &literal));
+
+ LiteralTestUtil::ExpectR1Equal<float>({1.25, 2.5, -17.0, -20.125}, literal);
+}
+
+TEST_F(CpuTransferManagerTest, TransferR1U8FromDevice) {
+ std::vector<uint8> storage{'k', 'l', 'm', 'n'};
+ se::DeviceMemoryBase memptr(storage.data(),
+ storage.size() * sizeof(storage[0]));
+ Literal literal;
+ const Shape shape = ShapeUtil::MakeShape(U8, {4});
+ TF_CHECK_OK(transfer_manager_.TransferLiteralFromDevice(
+ stream_exec_, memptr, shape, shape, &literal));
+ CHECK_EQ("klmn", literal.u8s());
+}
+
+TEST_F(CpuTransferManagerTest, TransferBufferFromDevice) {
+ std::vector<uint64> storage{1, 5, 42};
+ int64 size = storage.size() * sizeof(storage[0]);
+ se::DeviceMemoryBase memptr(storage.data(), size);
+
+ std::vector<uint64> dest(3, 0);
+ TF_CHECK_OK(transfer_manager_.TransferBufferFromDevice(stream_exec_, memptr,
+ size, dest.data()));
+ ASSERT_EQ(1, dest[0]);
+ ASSERT_EQ(5, dest[1]);
+ ASSERT_EQ(42, dest[2]);
+}
+
+TEST_F(CpuTransferManagerTest, TransferBufferToDevice) {
+ int64 size = 3 * sizeof(uint64);
+ std::vector<uint8> storage(size, 0);
+ se::DeviceMemoryBase memptr(storage.data(), size);
+
+ std::vector<uint64> dest{1, 5, 42};
+ TF_CHECK_OK(transfer_manager_.TransferBufferToDevice(stream_exec_, size,
+ dest.data(), &memptr));
+ // Reinterpret the raw byte storage as uint64 values; casting the vector
+ // object itself would be undefined behavior.
+ const uint64* storage64 = reinterpret_cast<const uint64*>(storage.data());
+ ASSERT_EQ(1, storage64[0]);
+ ASSERT_EQ(5, storage64[1]);
+ ASSERT_EQ(42, storage64[2]);
+}
+
+// TODO(b/24679870): add similar tests for GPUs
+
+} // namespace
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
new file mode 100644
index 0000000000..fb4ff1e68e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/transpose_folding.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+namespace {
+
+bool IsOperandFoldableToDot(const HloInstruction& hlo) {
+ return hlo.IsRank2Transpose() &&
+ hlo.user_count() == 1; // The dot is its only user.
+}
+
+bool CanFoldOperandsIntoDot(
+ const HloInstruction& dot,
+ const TransposeFolding::IsTransposableGemmFn& is_transposable_gemm) {
+ if (HloOpcode::kDot != dot.opcode()) {
+ return false;
+ }
+
+ if (!is_transposable_gemm(dot)) {
+ return false;
+ }
+
+ const HloInstruction* lhs = dot.operand(0);
+ const HloInstruction* rhs = dot.operand(1);
+ bool lhs_foldable = IsOperandFoldableToDot(*lhs);
+ bool rhs_foldable = IsOperandFoldableToDot(*rhs);
+ if (!lhs_foldable && !rhs_foldable) {
+ return false;
+ }
+ return true;
+}
+
+// Folds the operands of `dot` that are foldable transposes. `computation` is
+// the parent HLO computation of `dot`.
+//
+// Returns whether the computation was changed.
+bool FoldTransposeIntoDot(HloInstruction* dot, HloComputation* computation) {
+ std::vector<HloInstruction*> instructions_to_fuse(1, dot);
+ for (HloInstruction* operand : dot->operands()) {
+ if (IsOperandFoldableToDot(*operand)) {
+ instructions_to_fuse.push_back(operand);
+ }
+ }
+
+ // Early-exit if no operands are foldable.
+ if (instructions_to_fuse.size() == 1) {
+ return false;
+ }
+
+ computation->CreateFusionInstruction(
+ instructions_to_fuse, HloInstruction::FusionKind::kTransposeDot);
+ return true;
+}
+
+} // namespace
+
+TransposeFolding::TransposeFolding(IsTransposableGemmFn is_transposable_gemm)
+ : HloPass("transpose-folding"),
+ is_transposable_gemm_(std::move(is_transposable_gemm)) {}
+
+StatusOr<bool> TransposeFolding::Run(HloModule* module) {
+ // Modifying the graph while traversing is dangerous, so we find all folding
+ // opportunities before actually folding them.
+ HloComputation* entry_computation = module->entry_computation();
+
+ std::vector<HloInstruction*> foldable_dots;
+ auto visit_fn = [this, &foldable_dots](HloInstruction* instruction) {
+ if (CanFoldOperandsIntoDot(*instruction, is_transposable_gemm_)) {
+ foldable_dots.emplace_back(instruction);
+ }
+ return tensorflow::Status::OK();
+ };
+ TF_RETURN_IF_ERROR(entry_computation->root_instruction()->Accept(visit_fn));
+
+ bool changed = false;
+ for (HloInstruction* dot : foldable_dots) {
+ changed |= FoldTransposeIntoDot(dot, entry_computation);
+ }
+ return changed;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transpose_folding.h b/tensorflow/compiler/xla/service/transpose_folding.h
new file mode 100644
index 0000000000..7bec2f2364
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transpose_folding.h
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass.h"
+
+namespace xla {
+
+// HLO pass that folds transpose operators into Dot operators, where the Dot
+// operator is implemented by a GEMM kernel that can transpose its inputs.
+class TransposeFolding : public HloPass {
+ public:
+ // IsTransposableGemmFn should return true iff the instruction argument is
+ // implemented as a GEMM kernel that supports transposing its arguments.
+ typedef std::function<bool(const HloInstruction&)> IsTransposableGemmFn;
+ explicit TransposeFolding(IsTransposableGemmFn is_transposable_gemm);
+
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ IsTransposableGemmFn is_transposable_gemm_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TRANSPOSE_FOLDING_H_
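A minimal sketch of constructing and running the pass (an illustration, not part of the patch): `module` is assumed to be a valid HloModule* with an entry computation, and the predicate here simply accepts every dot, whereas the tests below use gpu::ImplementedAsGemm as the real check.

  // Treat every dot as a transposable GEMM, for illustration purposes only.
  xla::TransposeFolding transpose_folding(
      [](const xla::HloInstruction&) { return true; });
  TF_ASSIGN_OR_RETURN(bool changed, transpose_folding.Run(module));
  if (changed) {
    VLOG(1) << "Folded at least one rank-2 transpose into a dot fusion.";
  }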
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
new file mode 100644
index 0000000000..09f932e29e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -0,0 +1,149 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/transpose_folding.h"
+
+#include <memory>
+#include <set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+class TransposeFoldingTest : public ::testing::Test {
+ protected:
+ void FoldTranspose(HloModule* module) {
+ TransposeFolding transpose_folding(gpu::ImplementedAsGemm);
+ EXPECT_IS_OK(transpose_folding.Run(module).status());
+ }
+};
+
+TEST_F(TransposeFoldingTest, FoldTranspose) {
+ auto builder = HloComputation::Builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {2, 3}),
+ /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3}),
+ /*name=*/"y"));
+ HloInstruction* transpose_y =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot,
+ /*lhs=*/x, /*rhs=*/transpose_y));
+
+ HloModule module("test_module");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(dot));
+ FoldTranspose(&module);
+
+ // Instructions after folding: x, y, and the fusion.
+ std::set<HloInstruction*> instruction_set;
+ for (auto& instruction : entry_computation->instructions()) {
+ instruction_set.insert(instruction.get());
+ }
+ CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+ CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+ CHECK_EQ(1, instruction_set.size())
+ << "entry_computation should contain exactly 3 instructions.";
+ HloInstruction* fusion = *instruction_set.begin();
+ EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
+
+ // The fusion instruction should contain two parameters, one transpose and
+ // one dot.
+ EXPECT_EQ(4, fusion->fused_instructions().size());
+}
+
+TEST_F(TransposeFoldingTest, FoldTransposeConstant) {
+ auto builder = HloComputation::Builder("entry_computation");
+ // 2x1
+ HloInstruction* const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<float>({{1}, {2}})));
+ // 3x2
+ HloInstruction* const1 =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1, 2}, {3, 4}, {5, 6}})));
+ HloInstruction* transpose0 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0}));
+ HloInstruction* transpose1 =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot,
+ /*lhs=*/transpose0, /*rhs=*/transpose1));
+
+ HloModule module("test_module");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(dot));
+ FoldTranspose(&module);
+
+ for (auto& instruction : entry_computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ CHECK_EQ(2, instruction->operand_count());
+ EXPECT_EQ(const0, instruction->operand(0));
+ EXPECT_EQ(const1, instruction->operand(1));
+ }
+ }
+
+ // The created fusion instruction should contain two parameters, two
+ // transposes (one for each parameter) and one dot.
+ EXPECT_EQ(5,
+ entry_computation->root_instruction()->fused_instructions().size());
+}
+
+TEST_F(TransposeFoldingTest, FuseWithConstantOperands) {
+ auto builder = HloComputation::Builder("entry");
+ // (1.0 + 2.0) * (2.0 - 3.0)
+ HloInstruction* const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ HloInstruction* const2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ HloInstruction* const3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+ HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
+ const1->shape(), HloOpcode::kAdd, const1, const2));
+ HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
+ const2->shape(), HloOpcode::kSubtract, const2, const3));
+ HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ add->shape(), HloOpcode::kMultiply, add, sub));
+
+ HloModule module("fuse_with_constant_operands");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(mul));
+ HloInstruction* call = module.OutlineExpressionFromComputation(
+ {add, sub, mul}, "", entry_computation);
+ EXPECT_EQ(call, entry_computation->root_instruction());
+ HloComputation* callee_computation = call->to_apply();
+ // The arguments to the call should be const1, const2, and const3.
+ EXPECT_MATCH(call->operands(), testing::UnorderedMatcher<HloInstruction*>(
+ const1, const2, const3));
+
+ // The callee should contain 3 parameters and 3 binary operators.
+ EXPECT_EQ(6, callee_computation->instructions().size());
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
new file mode 100644
index 0000000000..0e0c0b02e3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -0,0 +1,495 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+
+#include <ostream>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+string BufferAlias::ToString() const {
+ return tensorflow::strings::StrCat("BufferAlias(", instruction_->name(), "[",
+ tensorflow::str_util::Join(index_, ","),
+ "] => ", buffer_->ToString(), ")");
+}
+
+std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias) {
+ out << buffer_alias.ToString();
+ return out;
+}
+
+bool PointsToSet::IsAmbiguous() const {
+ bool ambiguous = false;
+ TF_CHECK_OK(ForEachElement(
+ [&ambiguous](const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& points_to) {
+ ambiguous |= points_to.size() > 1;
+ return Status::OK();
+ }));
+ return ambiguous;
+}
+
+bool PointsToSet::IsDistinct() const {
+ bool distinct = true;
+ std::set<const LogicalBuffer*> all_points_to;
+ TF_CHECK_OK(ForEachElement([&distinct, &all_points_to](
+ const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& points_to) {
+ for (auto& buffer : points_to) {
+ if (all_points_to.count(buffer) != 0) {
+ distinct = false;
+ }
+ all_points_to.insert(buffer);
+ }
+ return Status::OK();
+ }));
+ return distinct;
+}
+
+size_t PointsToSet::size() const {
+  // Because pointed-to elements may be duplicated, we have to create a flattened
+  // set and return its size.
+ return CreateFlattenedSet().size();
+}
+
+std::set<const LogicalBuffer*> PointsToSet::CreateFlattenedSet() const {
+ std::set<const LogicalBuffer*> flat_set;
+ TF_CHECK_OK(ForEachElement(
+ [&flat_set](const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& buffers) {
+ flat_set.insert(buffers.begin(), buffers.end());
+ return Status::OK();
+ }));
+ return flat_set;
+}
+
+bool PointsToSet::ContainsBuffer(const LogicalBuffer& buffer) const {
+ bool found = false;
+ TF_CHECK_OK(ForEachElement([&found, &buffer](
+ const ShapeIndex& /*index*/, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
+ if (!found &&
+ std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
+ &buffer) != pointed_to_buffers.end()) {
+ found = true;
+ }
+ return Status::OK();
+ }));
+ return found;
+}
+
+bool PointsToSet::ContainsBufferAtIndex(const LogicalBuffer& buffer,
+ const ShapeIndex& index) const {
+ const std::vector<const LogicalBuffer*>& pointed_to_buffers = element(index);
+ return std::find(pointed_to_buffers.begin(), pointed_to_buffers.end(),
+ &buffer) != pointed_to_buffers.end();
+}
+
+void PointsToSet::AddPointedToBuffer(const LogicalBuffer& buffer,
+ const ShapeIndex& index) {
+ if (ContainsBufferAtIndex(buffer, index)) {
+ return;
+ }
+ mutable_element(index)->push_back(&buffer);
+}
+
+const std::set<HloInstruction*>& PointsToSet::tuple_sources(
+ const ShapeIndex& index) const {
+ return tuple_sources_.element(index);
+}
+
+void PointsToSet::add_tuple_source(const ShapeIndex& index,
+ HloInstruction* tuple) {
+ tuple_sources_.mutable_element(index)->insert(tuple);
+}
+
+/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
+TuplePointsToAnalysis::Run(const HloModule* module) {
+ std::unique_ptr<TuplePointsToAnalysis> analysis(
+ new TuplePointsToAnalysis(module));
+ TF_RETURN_IF_ERROR(analysis->Analyze());
+ return std::move(analysis);
+}
+
+Status TuplePointsToAnalysis::Analyze() {
+ points_to_.clear();
+ for (auto& computation : module_->computations()) {
+ TF_RETURN_IF_ERROR(computation->Accept(this));
+ for (auto& instruction : computation->instructions()) {
+ TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction(
+ instruction.get(), &instruction_defined_buffers_[instruction.get()]));
+
+ const PointsToSet& points_to_set = GetPointsToSet(instruction.get());
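+      // Record every (instruction, index) pair in this points-to set as an
+      // alias of the buffer it points to.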
+ TF_RETURN_IF_ERROR(points_to_set.ForEachElement([this, &instruction](
+ const ShapeIndex& index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& pointed_to_buffers) {
+ for (const LogicalBuffer* buffer : pointed_to_buffers) {
+ if (buffer_aliases_.count(buffer) == 0) {
+ buffer_aliases_.insert({buffer, std::vector<BufferAlias>()});
+ }
+ buffer_aliases_[buffer].emplace_back(*buffer, instruction.get(),
+ index);
+ }
+ return Status::OK();
+ }));
+ }
+ }
+
+ XLA_VLOG_LINES(3, ToString());
+
+ return Status::OK();
+}
+
+const LogicalBuffer& TuplePointsToAnalysis::NewLogicalBuffer(
+ HloInstruction* instruction, const ShapeIndex& index) {
+ CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
+ logical_buffers_.push_back(
+ MakeUnique<LogicalBuffer>(instruction, index, next_buffer_id_));
+ ++next_buffer_id_;
+ return *logical_buffers_.back();
+}
+
+Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
+ // Create trivial points-to set for instruction. Each points-to set at index i
+ // contains a single element LogicalBuffer(hlo_instruction, i). This indicates
+ // that this instruction is the source of all buffers in its own output.
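+  // For example, an instruction with a two-element tuple result gets three
+  // fresh LogicalBuffers: one at the top-level index {} and one at each of
+  // {0} and {1}, all defined by the instruction itself.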
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(hlo_instruction);
+ TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement(
+ [this, hlo_instruction](const ShapeIndex& index, bool /*is_leaf*/,
+ std::vector<const LogicalBuffer*>* buffers) {
+ const LogicalBuffer& buffer = NewLogicalBuffer(hlo_instruction, index);
+ buffers->push_back(&buffer);
+ return Status::OK();
+ }));
+
+ if (ShapeUtil::IsTuple(hlo_instruction->shape())) {
+    // If the HLO instruction is tuple-shaped, then trivially the instruction
+ // itself is the source of the tuple.
+ points_to_set.add_tuple_source({}, hlo_instruction);
+ }
+
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleGetTupleElement(
+ HloInstruction* get_tuple_element, HloInstruction* operand) {
+ // GetTupleElement forwards a pointer to a particular element of the tuple
+ // operand.
+ int64 element_index = get_tuple_element->tuple_index();
+
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element);
+ const PointsToSet& operand_points_to_set = *FindOrDie(points_to_, operand);
+
+ // Copy the points-to set (and tuple sources) at index {element_index} of the
+ // operand to the points-to set for this GetTupleElement instruction.
+ TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement([&, this](
+ const ShapeIndex& target_index, bool /*is_leaf*/,
+ std::vector<const LogicalBuffer*>* points_to) {
+ // Construct an index into the operand by prepending element_index to the
+ // index for the GetTupleElement instruction's points-to set.
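+    // For example, with element_index 1 and target_index {0, 2}, the operand
+    // index is {1, 0, 2}.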
+ ShapeIndex src_index;
+ src_index.push_back(element_index);
+ for (auto element : target_index) {
+ src_index.push_back(element);
+ }
+
+ *points_to = operand_points_to_set.element(src_index);
+ for (HloInstruction* tuple :
+ operand_points_to_set.tuple_sources(src_index)) {
+ points_to_set.add_tuple_source(target_index, tuple);
+ }
+ return Status::OK();
+ }));
+
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy,
+ HloInstruction* operand) {
+ // A kCopy instruction performs a shallow copy of the operand. The top-level
+ // buffer (index={}) is newly created, but all other buffers (in the case of a
+  // tuple shape) come from the operand.
+ PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand);
+ points_to_set.mutable_element(/*index=*/{})->clear();
+ points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}),
+ /*index=*/{});
+
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) {
+ // A kBitcast instruction aliases its operand. That is, the buffer of its
+  // result *is* the buffer of its operand, so just copy the operand's points-to
+ // set.
+ CreateCopiedPointsToSet(bitcast, bitcast->operand(0));
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
+ PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple);
+ points_to_set.AddPointedToBuffer(NewLogicalBuffer(tuple, /*index=*/{}),
+ /*index=*/{});
+
+ // A tuple contains references to all input operands and transitively any
+ // references in those operands.
+ for (int64 i = 0; i < operands.size(); ++i) {
+ const PointsToSet& operand_points_to_set =
+ *FindOrDie(points_to_, operands[i]);
+
+ // Copy the points-to set (and tuple sources) of the operand into the
+    // respective subtree of the tuple instruction's points-to set.
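+    // For example, the points-to set at index {0} of operand i lands at index
+    // {i, 0} of the tuple's points-to set.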
+ TF_RETURN_IF_ERROR(operand_points_to_set.ForEachElement(
+ [&points_to_set, &operand_points_to_set, i](
+ const ShapeIndex& src_index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& points_to) {
+ ShapeIndex target_index;
+ target_index.push_back(i);
+ for (auto element : src_index) {
+ target_index.push_back(element);
+ }
+
+ *points_to_set.mutable_element(target_index) = points_to;
+
+ for (HloInstruction* tuple :
+ operand_points_to_set.tuple_sources(src_index)) {
+ points_to_set.add_tuple_source(target_index, tuple);
+ }
+ return Status::OK();
+ }));
+ }
+
+ points_to_set.add_tuple_source({}, tuple);
+
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select,
+ HloInstruction* /*pred*/,
+ HloInstruction* on_true,
+ HloInstruction* on_false) {
+ // Select allocates a new buffer and then shallow copies the on_true or
+ // on_false buffer into this new buffer. Which side is chosen cannot be
+  // determined statically, so conservatively set the points-to set to the
+  // union of the on_true and on_false operands.
+ //
+ // First create a copy of the on_true points-to set (and tuple sources), then
+ // add in elements of the on_false points-to set (tuple sources).
+ PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true);
+ const PointsToSet& false_points_to_set = *FindOrDie(points_to_, on_false);
+ TF_RETURN_IF_ERROR(points_to_set.ForEachMutableElement(
+ [&](const ShapeIndex& index, bool /*is_leaf*/,
+ std::vector<const LogicalBuffer*>* buffers) {
+ for (const LogicalBuffer* false_buffer :
+ false_points_to_set.element(index)) {
+ points_to_set.AddPointedToBuffer(*false_buffer, index);
+ }
+
+ for (HloInstruction* tuple : false_points_to_set.tuple_sources(index)) {
+ points_to_set.add_tuple_source(index, tuple);
+ }
+ return Status::OK();
+ }));
+
+ // Select creates a new (top-level) buffer to store its result, so its
+ // respective element in the points-to set should contain only itself.
+ points_to_set.mutable_element({})->clear();
+ points_to_set.AddPointedToBuffer(NewLogicalBuffer(select, /*index=*/{}),
+ /*index=*/{});
+ return Status::OK();
+}
+
+Status TuplePointsToAnalysis::HandleFusion(HloInstruction* fusion) {
+ return ShapeUtil::IsTuple(fusion->shape())
+ ? Unimplemented("HandleFusion with tuple output")
+ : DefaultAction(fusion);
+}
+
+const PointsToSet& TuplePointsToAnalysis::GetPointsToSet(
+ const HloInstruction* hlo_instruction) const {
+ return *FindOrDie(points_to_, hlo_instruction);
+}
+
+PointsToSet& TuplePointsToAnalysis::CreateEmptyPointsToSet(
+ const HloInstruction* instruction) {
+ CHECK_EQ(0, points_to_.count(instruction));
+ points_to_[instruction] = MakeUnique<PointsToSet>(instruction->shape());
+ return *FindOrDie(points_to_, instruction);
+}
+
+bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex(
+ HloInstruction* instruction, const ShapeIndex& index) const {
+ const std::vector<const LogicalBuffer*>& buffers =
+ GetPointsToSet(instruction).element(index);
+ return (buffers.size() == 1 && buffers[0]->instruction() == instruction);
+}
+
+Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const {
+ if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) {
+ return FailedPrecondition(
+ "LogicalBuffer %s is ill-defined: instruction %s does not define a "
+ "buffer at that index",
+ buffer.ToString().c_str(), buffer.instruction()->name().c_str());
+ }
+
+ if (buffer.id() < 0 || buffer.id() >= next_buffer_id_) {
+ return FailedPrecondition(
+ "LogicalBuffer %s is ill-defined: invalid id %lld",
+ buffer.ToString().c_str(), buffer.id());
+ }
+ if (GetBuffer(buffer.id()).instruction() != buffer.instruction() ||
+ GetBuffer(buffer.id()).index() != buffer.index()) {
+ return FailedPrecondition(
+ "LogicalBuffer %s is ill-defined: buffer with same id differs: %s",
+ buffer.ToString().c_str(), GetBuffer(buffer.id()).ToString().c_str());
+ }
+
+ return Status::OK();
+}
+
+const LogicalBuffer& TuplePointsToAnalysis::GetBuffer(
+ LogicalBuffer::Id id) const {
+ CHECK_GE(id, 0);
+ CHECK_LT(id, logical_buffers_.size());
+ return *logical_buffers_[id];
+}
+
+StatusOr<const LogicalBuffer*> TuplePointsToAnalysis::GetBufferDefinedAt(
+ const HloInstruction* instruction, const ShapeIndex& index) const {
+ const std::vector<const LogicalBuffer*>& buffers =
+ GetPointsToSet(instruction).element(index);
+ if (buffers.size() != 1 || buffers[0]->instruction() != instruction) {
+ return FailedPrecondition(
+ "instruction %s does not define buffer at index {%s}",
+ instruction->name().c_str(),
+ tensorflow::str_util::Join(index, ",").c_str());
+ }
+ return buffers[0];
+}
+
+const std::vector<BufferAlias>& TuplePointsToAnalysis::GetBufferAliases(
+ const LogicalBuffer& buffer) const {
+ return buffer_aliases_.at(&buffer);
+}
+
+const std::vector<const LogicalBuffer*>&
+TuplePointsToAnalysis::GetBuffersDefinedByInstruction(
+ const HloInstruction* instruction) const {
+ return instruction_defined_buffers_.at(instruction);
+}
+
+Status TuplePointsToAnalysis::GatherBuffersDefinedByInstruction(
+ const HloInstruction* instruction,
+ std::vector<const LogicalBuffer*>* buffers) {
+ return GetPointsToSet(instruction)
+ .ForEachElement([this, buffers, instruction](
+ const ShapeIndex& index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& source_buffers) {
+ // Add buffers which 'instruction' is the source of.
+ CHECK(!source_buffers.empty());
+ if (source_buffers.size() == 1 &&
+ source_buffers[0]->instruction() == instruction) {
+              // If this instruction is the source of this buffer, the
+              // indices must match.
+ DCHECK(source_buffers[0]->index() == index);
+ buffers->push_back(source_buffers[0]);
+ } else {
+            // If the points-to set includes more than one buffer, then
+ // necessarily this instruction did not produce the
+ // buffer.
+ for (const LogicalBuffer* source_buffer : source_buffers) {
+ DCHECK(source_buffer->instruction() != instruction);
+ }
+ }
+ return Status::OK();
+ });
+}
+
+PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet(
+ const HloInstruction* instruction, const HloInstruction* src) {
+  // PointsToSet doesn't have a copy constructor, so copy over element by
+  // element from the src PointsToSet.
+ PointsToSet& dst_points_to_set = CreateEmptyPointsToSet(instruction);
+ const PointsToSet& src_points_to_set = GetPointsToSet(src);
+ TF_CHECK_OK(dst_points_to_set.ForEachMutableElement(
+ [this, &dst_points_to_set, &src_points_to_set](
+ const ShapeIndex& index, bool /*is_leaf*/,
+ std::vector<const LogicalBuffer*>* buffers) {
+ *buffers = src_points_to_set.element(index);
+ for (auto& tuple_source : src_points_to_set.tuple_sources(index)) {
+ dst_points_to_set.add_tuple_source(index, tuple_source);
+ }
+ return Status::OK();
+ }));
+ return *FindOrDie(points_to_, instruction);
+}
+
+string TuplePointsToAnalysis::ToString() const {
+ string output = tensorflow::strings::Printf(
+ "TuplePointsToSet for module %s:\n", module_->name().c_str());
+ for (auto& computation : module_->computations()) {
+ tensorflow::strings::StrAppend(&output, "computation ",
+ computation->name().c_str(), ":\n");
+ for (const HloInstruction* instruction :
+ computation->MakeInstructionPostOrder()) {
+ tensorflow::strings::StrAppend(&output, " instruction ",
+ instruction->ToShortString(), ":\n");
+ const PointsToSet& points_to_set = GetPointsToSet(instruction);
+ TF_CHECK_OK(points_to_set.ForEachElement(
+ [&output](const ShapeIndex& index, bool /*is_leaf*/,
+ const std::vector<const LogicalBuffer*>& points_to) {
+ tensorflow::strings::StrAppend(
+ &output, " {", tensorflow::str_util::Join(index, ","), "}: ",
+ tensorflow::str_util::Join(
+ points_to, ", ",
+ [](string* out, const LogicalBuffer* source) {
+ out->append(source->ToString());
+ }),
+ "\n");
+ return Status::OK();
+ }));
+ }
+ for (auto& buffer : logical_buffers_) {
+ tensorflow::strings::StrAppend(&output, " buffer ", buffer->ToString(),
+ ":\n");
+ for (const BufferAlias& buffer_alias : buffer_aliases_.at(buffer.get())) {
+ tensorflow::strings::StrAppend(&output, " alias ",
+ buffer_alias.ToString(), "\n");
+ }
+ }
+ }
+
+ tensorflow::strings::StrAppend(&output, "LogicalBuffers:\n");
+ for (const auto& buffer : logical_buffers_) {
+ tensorflow::strings::StrAppend(&output, " ", buffer->ToString());
+ }
+ return output;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
new file mode 100644
index 0000000000..7a3eb772d6
--- /dev/null
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -0,0 +1,268 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
+
+#include <stddef.h>
+#include <iosfwd>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A class describing the source(s) of the Buffer(s) contained in the output of
+// a particular HLO instruction. The structure of PointsToSet mirrors the
+// structure of the instruction's shape, which may be an arbitrary tree (e.g., a
+// nested tuple). Each node in this tree corresponds to a single buffer in the
+// instruction's output and contains the set of Buffers which might define
+// the corresponding buffer.
+class PointsToSet : public ShapeTree<std::vector<const LogicalBuffer*>> {
+ public:
+ explicit PointsToSet(const Shape& shape)
+ : ShapeTree<std::vector<const LogicalBuffer*>>(shape),
+ tuple_sources_(shape) {}
+
+  // Returns true if the points-to set of any subshape element is not a
+ // singleton.
+ bool IsAmbiguous() const;
+
+ // Returns true if no LogicalBuffer appears in more than one points-to set of
+ // the shape nodes.
+ bool IsDistinct() const;
+
+ // Returns the total number of different LogicalBuffers contained in this
+ // object. This is equal to CreateFlattenedSet().size().
+ size_t size() const;
+
+ // Creates a set containing the union of all LogicalBuffers contained in the
+ // PointsToSet.
+ std::set<const LogicalBuffer*> CreateFlattenedSet() const;
+
+ // Returns true if the given buffer is in the points-to set at the given
+ // index.
+ bool ContainsBufferAtIndex(const LogicalBuffer& buffer,
+ const ShapeIndex& index) const;
+
+ // Returns true if the given buffer is in the points-to set at any index.
+ bool ContainsBuffer(const LogicalBuffer& buffer) const;
+
+ // Adds the given buffer to the points-to set at the given index. This is a
+ // nop if the buffer already is in the set at that index.
+ void AddPointedToBuffer(const LogicalBuffer& buffer, const ShapeIndex& index);
+
+ // For the subshape at the given index (where index is defined as in
+ // ShapeUtil::GetSubshape) this method returns the set of HLO instructions
+ // which may produce the tuple subshape at that index. For example, given:
+ //
+ // %tuple1 = tuple(...)
+ // %tuple2 = tuple(...)
+ // %select = select(%tuple1, %tuple2)
+ // %nested_tuple = tuple(%select, %tuple1)
+ //
+ // These are the values for tuple_sources() for the PointsToSet of
+ // %nested_tuple:
+ //
+ // tuple_sources({}) = {%nested_tuple}
+ // tuple_sources({0}) = {%tuple1, %tuple2}
+ // tuple_sources({1}) = {%tuple1}
+ //
+ // tuple_sources() at the index of an array shape (not a tuple) returns the
+  // empty set. The instructions in the set returned by tuple_sources are
+  // necessarily either Tuple instructions, constants, or parameters.
+ const std::set<HloInstruction*>& tuple_sources(const ShapeIndex& index) const;
+
+ // Add a tuple source instruction for the given index.
+ void add_tuple_source(const ShapeIndex& index, HloInstruction* tuple);
+
+ private:
+ ShapeTree<std::set<HloInstruction*>> tuple_sources_;
+
+ // PointsToSet contains references (const LogicalBuffer*) to elements within
+ // TuplePointsToAnalysis so disable copying.
+ TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet);
+};
+
+// This class describes a particular subshape in a computation (instruction and
+// shape index) and the logical buffer which may be a source of the subshape
+// value.
+class BufferAlias {
+ public:
+ BufferAlias(const LogicalBuffer& buffer, HloInstruction* instruction,
+ const ShapeIndex& index)
+ : buffer_(&buffer), instruction_(instruction), index_(index) {}
+
+ // Return the logical buffer aliased at the instruction and index.
+ const LogicalBuffer& buffer() const { return *buffer_; }
+
+ // Return the instruction/index of the subshape.
+ HloInstruction* instruction() const { return instruction_; }
+ const ShapeIndex& index() const { return index_; }
+
+ bool operator==(const BufferAlias& other) const {
+ return buffer_ == other.buffer_ && instruction_ == other.instruction_ &&
+ index_ == other.index_;
+ }
+ bool operator!=(const BufferAlias& other) const { return !(*this == other); }
+
+ string ToString() const;
+
+ private:
+ const LogicalBuffer* buffer_;
+ HloInstruction* instruction_;
+ const ShapeIndex index_;
+};
+
+std::ostream& operator<<(std::ostream& out, const BufferAlias& buffer_alias);
+
+// DFS visitor that performs tuple points-to analysis. This analysis determines
+// the potential sources of each buffer in each instruction's output.
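+//
+// Typical usage (a sketch mirroring the tests in this change):
+//
+//   std::unique_ptr<TuplePointsToAnalysis> analysis =
+//       TuplePointsToAnalysis::Run(module).ConsumeValueOrDie();
+//   const PointsToSet& points_to = analysis->GetPointsToSet(instruction);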
+class TuplePointsToAnalysis : public DfsHloVisitorWithDefault {
+ public:
+ static StatusOr<std::unique_ptr<TuplePointsToAnalysis>> Run(
+ const HloModule* module);
+
+ // Return the points-to set of an instruction. This describes the potential
+ // sources of each buffer in the instruction's output.
+ const PointsToSet& GetPointsToSet(
+ const HloInstruction* hlo_instruction) const;
+
+ // Returns the logical buffer with the given ID.
+ const LogicalBuffer& GetBuffer(LogicalBuffer::Id id) const;
+
+ // Returns the buffer defined at the given instruction and index. An error is
+ // returned if no buffer is defined at that point.
+ StatusOr<const LogicalBuffer*> GetBufferDefinedAt(
+ const HloInstruction* instruction, const ShapeIndex& index) const;
+
+  // Return a vector containing all BufferAliases of the given logical buffer.
+  // This trivially includes the BufferAlias with the same instruction and index
+  // as the logical buffer itself, so the returned vector is never empty. The
+ // buffer alias set is the inverse of the points-to set. That is,
+ // LogicalBuffer B is in the points-to set of instruction I at index N iff
+ // instruction I, index N is a BufferAlias of B.
+ const std::vector<BufferAlias>& GetBufferAliases(
+ const LogicalBuffer& buffer) const;
+
+ // Return a vector containing all logical buffers in the module.
+ const std::vector<std::unique_ptr<LogicalBuffer>>& logical_buffers() const {
+ return logical_buffers_;
+ }
+
+ // Returns a vector of buffers that the instruction produces. Most
+ // instructions produce a single buffer (the top-level buffer), some produce
+  // no buffers (e.g., bitcast), and some produce more than one buffer (e.g.,
+ // tuple-shaped parameters).
+ const std::vector<const LogicalBuffer*>& GetBuffersDefinedByInstruction(
+ const HloInstruction* instruction) const;
+
+ // Returns true if the given instruction defines a buffer at the given index.
+ bool InstructionDefinesBufferAtIndex(HloInstruction* instruction,
+ const ShapeIndex& index) const;
+
+ // Returns an OK status if the given buffer is defined by instruction
+ // 'buffer.instruction()' at index 'buffer.index()' and if the given buffer
+ // matches the TuplePointsToAnalysis' LogicalBuffer with 'buffer.id'. Returns
+  // a FailedPrecondition error status otherwise. An example of a LogicalBuffer
+  // which is not defined is a tuple element in a Tuple instruction. In this
+  // case, the Tuple instruction does not define the LogicalBuffer; rather, that
+ // index aliases one of its operands.
+ Status VerifyBuffer(const LogicalBuffer& buffer) const;
+
+ Status DefaultAction(HloInstruction* hlo_instruction) override;
+ Status HandleTuple(
+ HloInstruction* tuple,
+ tensorflow::gtl::ArraySlice<HloInstruction*> operands) override;
+ Status HandleGetTupleElement(HloInstruction* get_tuple_element,
+ HloInstruction* operand) override;
+ Status HandleBitcast(HloInstruction* bitcast) override;
+ Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override;
+ Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleSelect(HloInstruction* select, HloInstruction* pred,
+ HloInstruction* on_true,
+ HloInstruction* on_false) override;
+
+ string ToString() const;
+
+ private:
+ explicit TuplePointsToAnalysis(const HloModule* module) : module_(module) {}
+
+ // Perform the analysis. Should be called immediately after constructing the
+ // object and before calling GetPointsToSet.
+ Status Analyze();
+
+ // Create a new logical buffer and return a reference to it. The newly created
+ // buffer is stored in an internal vector of LogicalBuffers and can be
+ // accessed with GetBuffer.
+ const LogicalBuffer& NewLogicalBuffer(HloInstruction* instruction,
+ const ShapeIndex& index);
+
+ // Creates an empty PointsToSet in the points_to_ map for the given
+ // instruction.
+ PointsToSet& CreateEmptyPointsToSet(const HloInstruction* instruction);
+
+ // Creates a PointsToSet in the points_to_ map for 'instruction' which is a
+ // copy of the existing PointsToSet for 'src'.
+ PointsToSet& CreateCopiedPointsToSet(const HloInstruction* instruction,
+ const HloInstruction* src);
+
+ // Adds the buffers defined by the given instruction to the given vector.
+ Status GatherBuffersDefinedByInstruction(
+ const HloInstruction* instruction,
+ std::vector<const LogicalBuffer*>* buffers);
+
+ // The module this analysis is performed on.
+ const HloModule* module_;
+
+ // A map containing a PointsToSet for every HLO instruction.
+ tensorflow::gtl::FlatMap<const HloInstruction*, std::unique_ptr<PointsToSet>>
+ points_to_;
+
+ // A map containing the LogicalBuffers defined by each HLO instruction.
+ std::unordered_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
+ instruction_defined_buffers_;
+
+ std::unordered_map<const LogicalBuffer*, std::vector<BufferAlias>>
+ buffer_aliases_;
+
+ // All logical buffers in the module, indexed by LogicalBuffer::Id. Keep as
+ // vector of std::unique_ptr to keep the underlying pointer values stable.
+ std::vector<std::unique_ptr<LogicalBuffer>> logical_buffers_;
+
+ // The ID of the next logical buffer created.
+ LogicalBuffer::Id next_buffer_id_ = 0;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TuplePointsToAnalysis);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_POINTS_TO_ANALYSIS_H_
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
new file mode 100644
index 0000000000..e4dd4d309e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -0,0 +1,544 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+
+#include <map>
+#include <memory>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class TuplePointsToAnalysisTest : public HloTestBase {
+ protected:
+  // Builds a module with the given entry computation and runs points-to
+ // analysis.
+ void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
+ module_.reset(new HloModule(TestName()));
+ module_->AddEntryComputation(std::move(computation));
+ points_to_analysis_ =
+ TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie();
+ }
+
+ // Returns the LogicalBuffer defined at the given instruction and
+ // index. CHECKs if no buffer is defined at that point.
+  const LogicalBuffer* GetBuffer(const HloInstruction* instruction,
+                                 const ShapeIndex& index) {
+ const std::vector<const LogicalBuffer*>& pointed_to =
+ points_to_analysis_->GetPointsToSet(instruction).element(index);
+ CHECK_EQ(1, pointed_to.size());
+ CHECK_EQ(instruction, pointed_to[0]->instruction());
+ CHECK(index == pointed_to[0]->index());
+ return pointed_to[0];
+ }
+
+ // Checks that the given points-to set contains exactly (unordered) the given
+ // LogicalBuffers.
+ void ExpectHasBuffers(
+ const std::vector<const LogicalBuffer*>& points_to_set,
+ tensorflow::gtl::ArraySlice<const LogicalBuffer*> buffers) {
+ std::vector<const LogicalBuffer*> vec(buffers.begin(), buffers.end());
+ EXPECT_MATCH(points_to_set, testing::UnorderedElementsAre(vec));
+ }
+
+ // Checks that the given points-to set contains exactly (unordered) the
+ // top-level buffers of the given instructions.
+ void ExpectHasTopLevelBuffers(
+ const std::vector<const LogicalBuffer*>& points_to_set,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ std::vector<const LogicalBuffer*> buffers;
+ for (auto instruction : instructions) {
+ buffers.push_back(GetBuffer(instruction, /*index=*/{}));
+ }
+ ExpectHasBuffers(points_to_set, buffers);
+ }
+
+ // Overload which takes a std::set instead of a std::vector.
+ void ExpectHasTopLevelBuffers(
+ const std::set<const LogicalBuffer*>& points_to_set,
+ tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+ ExpectHasTopLevelBuffers(std::vector<const LogicalBuffer*>(
+ points_to_set.begin(), points_to_set.end()),
+ instructions);
+ }
+
+ // Checks that the buffer defined at the given instruction and index has
+ // aliases which are exactly (unordered) the given instruction/index pairs.
+ void ExpectHasBufferAliases(
+ const HloInstruction* instruction, const ShapeIndex& index,
+ tensorflow::gtl::ArraySlice<std::pair<HloInstruction*, ShapeIndex>>
+ expected) {
+ const LogicalBuffer* buffer =
+ points_to_analysis_->GetBufferDefinedAt(instruction, index)
+ .ValueOrDie();
+ std::vector<BufferAlias> expected_aliases;
+ for (auto& pair : expected) {
+ expected_aliases.push_back(BufferAlias(*buffer, pair.first, pair.second));
+ }
+ EXPECT_MATCH(points_to_analysis_->GetBufferAliases(*buffer),
+ testing::UnorderedElementsAre(expected_aliases));
+ }
+
+ std::unique_ptr<HloModule> module_;
+ std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
+};
+
+// Expects that the given std::set<HloInstruction*> A contains exactly the
+// HloInstruction*s passed as __VA_ARGS__.
+#define EXPECT_ISET(A, ...) \
+ EXPECT_MATCH(testing::SetToVec<HloInstruction*>(A), \
+ testing::UnorderedMatcher<HloInstruction*>(__VA_ARGS__))
+
+TEST_F(TuplePointsToAnalysisTest, SimpleTuple) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+ EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant1).size());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
+ EXPECT_TRUE(
+ points_to_analysis_->GetPointsToSet(constant1).tuple_sources({}).empty());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct());
+
+ EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(constant2).size());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
+ EXPECT_TRUE(
+ points_to_analysis_->GetPointsToSet(constant2).tuple_sources({}).empty());
+
+ EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
+ EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
+ tuple);
+
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
+ {constant1, constant2, tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2});
+
+ const PointsToSet& tuple_points_to_set =
+ points_to_analysis_->GetPointsToSet(tuple);
+ EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(
+ *GetBuffer(constant1, {}), {0}));
+ EXPECT_TRUE(tuple_points_to_set.ContainsBufferAtIndex(
+ *GetBuffer(constant2, {}), {1}));
+ EXPECT_FALSE(tuple_points_to_set.ContainsBufferAtIndex(
+ *GetBuffer(constant2, {}), {0}));
+ EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant1, {})));
+ EXPECT_TRUE(tuple_points_to_set.ContainsBuffer(*GetBuffer(constant2, {})));
+}
+
+TEST_F(TuplePointsToAnalysisTest, NestedTuple) {
+ // Create a (nested) tuple containing an inner tuple. The points-to set of the
+ // outer tuple should contain all elements of the points-to set of the inner
+ // tuple.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto inner_tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({inner_tuple, constant3}));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(constant1).element({}), {constant1});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(constant2).element({}), {constant2});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(constant3).element({}), {constant3});
+
+ EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(inner_tuple).size());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(inner_tuple).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(inner_tuple).IsDistinct());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(inner_tuple).CreateFlattenedSet(),
+ {constant1, constant2, inner_tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(inner_tuple).element({}),
+ {inner_tuple});
+ EXPECT_ISET(
+ points_to_analysis_->GetPointsToSet(inner_tuple).tuple_sources({}),
+ inner_tuple);
+
+ EXPECT_EQ(5, points_to_analysis_->GetPointsToSet(tuple).size());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
+ {constant1, constant2, constant3, inner_tuple, tuple});
+
+ EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
+ tuple);
+ EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({0}),
+ inner_tuple);
+ EXPECT_TRUE(
+ points_to_analysis_->GetPointsToSet(tuple).tuple_sources({1}).empty());
+
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({0}), {inner_tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({0, 0}), {constant1});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({0, 1}), {constant2});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant3});
+}
+
+TEST_F(TuplePointsToAnalysisTest, GetTupleElement) {
+ // Create a nested tuple, then extract the inner tuple with GetTupleElement.
+  // The points-to set of the GetTupleElement should be the same as that of the
+  // inner tuple.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto inner_tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+
+ auto constant3 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(3.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({inner_tuple, constant3}));
+
+ auto get_tuple_element = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(inner_tuple->shape(), tuple, 0));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(get_tuple_element);
+ EXPECT_EQ(3, points_to_set.size());
+ EXPECT_FALSE(points_to_set.IsAmbiguous());
+ EXPECT_TRUE(points_to_set.IsDistinct());
+ ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
+ {constant1, constant2, inner_tuple});
+ ExpectHasTopLevelBuffers(points_to_set.element({}), {inner_tuple});
+
+ EXPECT_ISET(points_to_set.tuple_sources({}), inner_tuple);
+}
+
+TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) {
+ // Create a tuple which contains duplicate elements.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant, constant, constant}));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_EQ(2, points_to_analysis_->GetPointsToSet(tuple).size());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsDistinct());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
+ {constant, tuple});
+}
+
+TEST_F(TuplePointsToAnalysisTest, TupleCopy) {
+  // Create a copy (HloOpcode::kCopy) of a tuple. The points-to set of the copy
+  // should match the tuple's, except that the copy defines its own top-level
+  // buffer.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(copy).IsAmbiguous());
+ EXPECT_TRUE(points_to_analysis_->GetPointsToSet(copy).IsDistinct());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
+ {constant1, constant2, tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(copy).element({}), {copy});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(copy).CreateFlattenedSet(),
+ {constant1, constant2, copy});
+}
+
+TEST_F(TuplePointsToAnalysisTest, TupleSelect) {
+  // Select from two different tuples. This should create an ambiguous
+  // points-to set containing the union of both sides.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto tuple2 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant2, constant2}));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
+ EXPECT_EQ(3, points_to_set.size());
+ EXPECT_TRUE(points_to_set.IsAmbiguous());
+ EXPECT_FALSE(points_to_set.IsDistinct());
+ ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
+ ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1, constant2});
+ ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
+ ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
+ {constant1, constant2, select});
+}
+
+TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) {
+ // Create a Select which selects between two tuple parameters. Verify the
+ // points-to sets and tuple sources are properly set.
+ Shape tuple_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {1, 2, 3}), ShapeUtil::MakeShape(U32, {5})});
+
+ auto builder = HloComputation::Builder(TestName());
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param0"));
+ auto param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, tuple_shape, "param1"));
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple_shape, HloOpcode::kSelect, pred, param0, param1));
+ auto copy = builder.AddInstruction(
+ HloInstruction::CreateUnary(tuple_shape, HloOpcode::kCopy, select));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+  // The points-to set of each element of a tuple parameter should be the
+  // parameter itself at the appropriate index.
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({}),
+ {GetBuffer(param0, {})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({0}),
+ {GetBuffer(param0, {0})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(param0).element({1}),
+ {GetBuffer(param0, {1})});
+
+  // Select's points-to set of its subelements should be the respective
+ // subelements of param0 and param1. The top-level buffer, however, does not
+ // alias as it is created by the select instruction.
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({}),
+ {GetBuffer(select, {})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({0}),
+ {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(select).element({1}),
+ {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
+
+ // Copy should be identical to select other than the top-level buffer.
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({}),
+ {GetBuffer(copy, {})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({0}),
+ {GetBuffer(param0, {0}), GetBuffer(param1, {0})});
+ ExpectHasBuffers(points_to_analysis_->GetPointsToSet(copy).element({1}),
+ {GetBuffer(param0, {1}), GetBuffer(param1, {1})});
+}
+
+TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) {
+ // Select from two identical tuples. The result should not be ambiguous.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto tuple2 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
+ EXPECT_EQ(3, points_to_set.size());
+ EXPECT_FALSE(points_to_set.IsAmbiguous());
+ EXPECT_TRUE(points_to_set.IsDistinct());
+ ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
+ ExpectHasTopLevelBuffers(points_to_set.element({0}), {constant1});
+ ExpectHasTopLevelBuffers(points_to_set.element({1}), {constant2});
+ ExpectHasTopLevelBuffers(points_to_set.CreateFlattenedSet(),
+ {constant1, constant2, select});
+}
+
+TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) {
+ // Select from nested tuples. Verify that the nested points-to sets contain
+ // the right values.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto inner_tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto inner_tuple2 = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant2, constant2}));
+
+ auto tuple1 =
+ builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple1}));
+ auto tuple2 =
+ builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2}));
+
+ auto pred = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+ auto select = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(select);
+ EXPECT_EQ(5, points_to_set.size());
+ EXPECT_TRUE(points_to_set.IsAmbiguous());
+ EXPECT_FALSE(points_to_set.IsDistinct());
+
+ // Verify points-to set.
+ ExpectHasTopLevelBuffers(points_to_set.element({}), {select});
+ ExpectHasTopLevelBuffers(points_to_set.element({0}),
+ {inner_tuple1, inner_tuple2});
+ ExpectHasTopLevelBuffers(points_to_set.element({0, 0}),
+ {constant1, constant2});
+ ExpectHasTopLevelBuffers(points_to_set.element({0, 1}), {constant2});
+
+ // Verify tuple sources.
+ EXPECT_ISET(points_to_set.tuple_sources({}), tuple1, tuple2);
+ EXPECT_ISET(points_to_set.tuple_sources({0}), inner_tuple1, inner_tuple2);
+ EXPECT_EQ(0, points_to_set.tuple_sources({0, 0}).size());
+ EXPECT_EQ(0, points_to_set.tuple_sources({0, 1}).size());
+}
+
+TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) {
+ // Bitcast is an alias of its operand. A tuple with a bitcast element should
+ // have the operand of the bitcast in its points-to set.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant2->shape(), HloOpcode::kBitcast, constant2));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({constant1, bitcast}));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ EXPECT_EQ(1, points_to_analysis_->GetPointsToSet(bitcast).size());
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(bitcast).element({}), {constant2});
+ EXPECT_TRUE(
+ points_to_analysis_->GetPointsToSet(bitcast).tuple_sources({}).empty());
+
+ EXPECT_EQ(3, points_to_analysis_->GetPointsToSet(tuple).size());
+ EXPECT_FALSE(points_to_analysis_->GetPointsToSet(tuple).IsAmbiguous());
+ EXPECT_ISET(points_to_analysis_->GetPointsToSet(tuple).tuple_sources({}),
+ tuple);
+
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).CreateFlattenedSet(),
+ {constant1, constant2, tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({}), {tuple});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({0}), {constant1});
+ ExpectHasTopLevelBuffers(
+ points_to_analysis_->GetPointsToSet(tuple).element({1}), {constant2});
+}
+
+TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) {
+ // Construct a tuple constant and kCopy it. Verify the points-to set of the
+  // copy correctly points into the nested elements of the constant.
+ auto builder = HloComputation::Builder(TestName());
+ auto tuple_constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()})));
+ auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
+ tuple_constant->shape(), HloOpcode::kCopy, tuple_constant));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ auto& points_to_set = points_to_analysis_->GetPointsToSet(copy);
+
+ ExpectHasBuffers(points_to_set.element({}), {GetBuffer(copy, {})});
+ ExpectHasBuffers(points_to_set.element({0}),
+ {GetBuffer(tuple_constant, {0})});
+ ExpectHasBuffers(points_to_set.element({1}),
+ {GetBuffer(tuple_constant, {1})});
+}
+
+TEST_F(TuplePointsToAnalysisTest, BufferAliases) {
+ // Create a nested tuple in which individual elements appear multiple
+ // times. Verify buffer alias sets.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0)));
+ auto inner_tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({constant1, constant2}));
+ auto tuple = builder.AddInstruction(
+ HloInstruction::CreateTuple({inner_tuple, constant2}));
+
+ BuildModuleAndRunAnalysis(builder.Build());
+
+ ExpectHasBufferAliases(
+ constant1, /*index=*/{},
+ {{constant1, {}}, {inner_tuple, {0}}, {tuple, {0, 0}}});
+ ExpectHasBufferAliases(
+ constant2, /*index=*/{},
+ {{constant2, {}}, {inner_tuple, {1}}, {tuple, {0, 1}}, {tuple, {1}}});
+ ExpectHasBufferAliases(inner_tuple, /*index=*/{},
+ {{inner_tuple, {}}, {tuple, {0}}});
+ ExpectHasBufferAliases(tuple, /*index=*/{}, {{tuple, {}}});
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
new file mode 100644
index 0000000000..04029c7b01
--- /dev/null
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -0,0 +1,2117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/user_computation.h"
+
+#include <algorithm>
+#include <set>
+#include <utility>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/shape_inference.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+namespace {
+
+HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) {
+ switch (unop) {
+ case UNOP_ABS:
+ return HloOpcode::kAbs;
+ case UNOP_CEIL:
+ return HloOpcode::kCeil;
+ case UNOP_EXP:
+ return HloOpcode::kExp;
+ case UNOP_FLOOR:
+ return HloOpcode::kFloor;
+ case UNOP_LOG:
+ return HloOpcode::kLog;
+ case UNOP_LOGICAL_NOT:
+ return HloOpcode::kLogicalNot;
+ case UNOP_NEGATE:
+ return HloOpcode::kNegate;
+ case UNOP_SIGN:
+ return HloOpcode::kSign;
+ case UNOP_SORT:
+ return HloOpcode::kSort;
+ case UNOP_TANH:
+ return HloOpcode::kTanh;
+ default:
+ LOG(FATAL) << "unhandled operation " << unop;
+ }
+}
+
+HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
+ switch (binop) {
+ case BINOP_DOT:
+ return HloOpcode::kDot;
+ case BINOP_MUL:
+ return HloOpcode::kMultiply;
+ case BINOP_ADD:
+ return HloOpcode::kAdd;
+ case BINOP_SUB:
+ return HloOpcode::kSubtract;
+ case BINOP_INDEX:
+ return HloOpcode::kIndex;
+ case BINOP_DIV:
+ return HloOpcode::kDivide;
+ case BINOP_EQ:
+ return HloOpcode::kEq;
+ case BINOP_GE:
+ return HloOpcode::kGe;
+ case BINOP_GT:
+ return HloOpcode::kGt;
+ case BINOP_LE:
+ return HloOpcode::kLe;
+ case BINOP_LT:
+ return HloOpcode::kLt;
+ case BINOP_NE:
+ return HloOpcode::kNe;
+ case BINOP_MAX:
+ return HloOpcode::kMaximum;
+ case BINOP_MIN:
+ return HloOpcode::kMinimum;
+ case BINOP_POW:
+ return HloOpcode::kPower;
+ case BINOP_REM:
+ return HloOpcode::kRemainder;
+ case BINOP_LOGICAL_OR:
+ return HloOpcode::kLogicalOr;
+ case BINOP_LOGICAL_AND:
+ return HloOpcode::kLogicalAnd;
+ default:
+ LOG(FATAL) << "unhandled operation " << binop;
+ }
+}
+
+HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) {
+ switch (triop) {
+ case TRIOP_CLAMP:
+ return HloOpcode::kClamp;
+ case TRIOP_SELECT:
+ return HloOpcode::kSelect;
+ case TRIOP_UPDATE:
+ return HloOpcode::kUpdate;
+ default:
+ LOG(FATAL) << "unhandled operation " << triop;
+ }
+}
+
+HloOpcode VariadicOperationToHloOpcode(VariadicOperation varop) {
+ switch (varop) {
+ case VAROP_TUPLE:
+ return HloOpcode::kTuple;
+ default:
+ LOG(FATAL) << "unhandled operation " << varop;
+ }
+}
+
+} // namespace
+
+/* static */ StatusOr<std::unique_ptr<UserComputation>>
+UserComputation::MakeWithRemapping(
+ const SessionComputation& session_computation,
+ const ComputationHandle& handle,
+ const std::map<int64, ComputationHandle>& old_to_new) {
+ auto user_computation =
+ MakeUnique<UserComputation>(session_computation.name(), handle);
+ {
+ tensorflow::mutex_lock lock(user_computation->mutex_);
+ user_computation->session_computation_ = session_computation;
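+ // Resume handle numbering one past the largest handle already recorded in
+ // the session computation so newly added operations receive fresh handles.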
+ user_computation->next_handle_value_ =
+ std::max_element(session_computation.requests().begin(),
+ session_computation.requests().end(),
+ [](const std::pair<int64, OperationRequest>& lhs,
+ const std::pair<int64, OperationRequest>& rhs) {
+ return lhs.first < rhs.first;
+ })
+ ->first +
+ 1;
+ TF_RETURN_IF_ERROR(user_computation->RemapEmbeddedComputations(old_to_new));
+ }
+
+ return std::move(user_computation);
+}
+
+UserComputation::UserComputation(const string& name,
+ const ComputationHandle& handle)
+ : name_(name), next_handle_value_(1) {
+ *session_computation_.mutable_computation_handle() = handle;
+ session_computation_.set_name(name);
+}
+
+ComputationDataHandle UserComputation::CreateComputationDataHandle() {
+ ComputationDataHandle handle;
+ handle.set_handle(next_handle_value_);
+ // Handles are used as Version values and *must* be assigned consecutively for
+ // computation versioning to work.
+ next_handle_value_++;
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddParameterInstruction(
+ const ParameterRequest& parameter_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ int64 parameter_number = parameter_request.parameter();
+ if (parameters_.count(parameter_number) != 0) {
+ return InvalidArgument("parameter %lld already registered",
+ parameter_number);
+ }
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ const Shape& validated_shape = parameter_request.shape();
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
+
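+ // Record the operation in the session computation, keyed by the new
+ // handle's value; every Add*Instruction method below follows this pattern.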
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = validated_shape;
+ *request.mutable_request()->mutable_parameter_request() = parameter_request;
+
+ parameters_[parameter_number] = &request;
+
+ return handle;
+}
+
+Status UserComputation::AddSendInstruction(const SendRequest& send_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ *session_computation_.add_send_requests() = send_request;
+ // Check if the operand of the instruction is valid.
+ TF_RETURN_IF_ERROR(LookupRequest(send_request.operand()).status());
+ return Status::OK();
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddRecvInstruction(
+ const RecvRequest& recv_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ const Shape& shape = recv_request.shape();
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
+ ComputationDataHandle handle = CreateComputationDataHandle();
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_recv_request() = recv_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddPadInstruction(
+ const PadRequest& pad_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(pad_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* padding_value,
+ LookupRequest(pad_request.padding_value()));
+
+ TF_ASSIGN_OR_RETURN(Shape inferred_shape, ShapeInference::InferPadShape(
+ operand->output_shape(),
+ padding_value->output_shape(),
+ pad_request.padding_config()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ *request.mutable_request()->mutable_pad_request() = pad_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
+ const ConstantRequest& constant_request) {
+ const Shape& validated_shape = constant_request.literal().shape();
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
+
+ tensorflow::mutex_lock lock(mutex_);
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = validated_shape;
+ *request.mutable_request()->mutable_constant_request() = constant_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
+ const GetTupleElementRequest& get_tuple_element_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(get_tuple_element_request.operand()));
+ Shape element_shape = ShapeUtil::GetTupleElementShape(
+ operand->output_shape(), get_tuple_element_request.index());
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = element_shape;
+ *request.mutable_request()->mutable_get_tuple_element_request() =
+ get_tuple_element_request;
+
+ return handle;
+}
+
+Status UserComputation::AddTraceInstruction(const TraceRequest& trace_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Verify that the operand index is valid.
+ TF_RETURN_IF_ERROR(LookupRequest(trace_request.operand()).status());
+
+ *session_computation_.add_trace_requests() = trace_request;
+
+ return Status::OK();
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddRngInstruction(
+ const RngRequest& rng_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Check the number of parameters per RNG distribution.
+ switch (rng_request.distribution()) {
+ case RandomDistribution::RNG_BERNOULLI:
+ if (rng_request.parameter_size() != 1) {
+ return InvalidArgument(
+ "RNG distribution (%s) expects 1 parameter, but got %d",
+ RandomDistribution_Name(rng_request.distribution()).c_str(),
+ rng_request.parameter_size());
+ }
+ break;
+ case RandomDistribution::RNG_NORMAL:
+ case RandomDistribution::RNG_UNIFORM:
+ if (rng_request.parameter_size() != 2) {
+ return InvalidArgument(
+ "RNG distribution (%s) expects 2 parameters, but got %d",
+ RandomDistribution_Name(rng_request.distribution()).c_str(),
+ rng_request.parameter_size());
+ }
+ break;
+ default:
+ LOG(FATAL) << "unhandled distribution " << rng_request.distribution();
+ }
+
+ // Verify that the parameter indices are valid.
+ for (const ComputationDataHandle& param : rng_request.parameter()) {
+ TF_RETURN_IF_ERROR(LookupRequest(param).status());
+ }
+ const Shape& validated_shape = rng_request.shape();
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(validated_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = validated_shape;
+ *request.mutable_request()->mutable_rng_request() = rng_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddMapInstruction(
+ const MapRequest& map_request,
+ const UserComputation& to_apply_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ std::vector<const Shape*> operand_shapes;
+ for (const ComputationDataHandle& handle : map_request.operands()) {
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+ operand_shapes.push_back(&operand->output_shape());
+ }
+
+ VersionedComputationHandle::Version to_apply_version =
+ to_apply_computation.version();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> to_apply_program_shape,
+ to_apply_computation.ComputeProgramShape(to_apply_version));
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(to_apply_version);
+ *request.mutable_request()->mutable_map_request() = map_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddReduceInstruction(
+ const ReduceRequest& reduce_request,
+ const UserComputation& to_apply_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(reduce_request.operand()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
+ LookupRequest(reduce_request.init_value()));
+
+ VersionedComputationHandle::Version to_apply_version =
+ to_apply_computation.version();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> to_apply_program_shape,
+ to_apply_computation.ComputeProgramShape(to_apply_version));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferReduceShape(
+ operand->output_shape(), init_value->output_shape(),
+ AsInt64Slice(reduce_request.dimensions()), *to_apply_program_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(to_apply_version);
+ *request.mutable_request()->mutable_reduce_request() = reduce_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddReduceWindowInstruction(
+ const ReduceWindowRequest& reduce_window_request,
+ const UserComputation& to_apply_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(reduce_window_request.operand()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
+ LookupRequest(reduce_window_request.init_value()));
+
+ VersionedComputationHandle::Version to_apply_version =
+ to_apply_computation.version();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> to_apply_program_shape,
+ to_apply_computation.ComputeProgramShape(to_apply_version));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferReduceWindowShape(
+ operand->output_shape(), init_value->output_shape(),
+ reduce_window_request.window(), *to_apply_program_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(to_apply_version);
+ *request.mutable_request()->mutable_reduce_window_request() =
+ reduce_window_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddSelectAndScatterInstruction(
+ const SelectAndScatterRequest& select_and_scatter_request,
+ const UserComputation& select_computation,
+ const UserComputation& scatter_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(select_and_scatter_request.operand()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* source,
+ LookupRequest(select_and_scatter_request.source()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* init_value,
+ LookupRequest(select_and_scatter_request.init_value()));
+
+ VersionedComputationHandle::Version select_version =
+ select_computation.version();
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> select_program_shape,
+ select_computation.ComputeProgramShape(select_version));
+ VersionedComputationHandle::Version scatter_version =
+ scatter_computation.version();
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> scatter_program_shape,
+ scatter_computation.ComputeProgramShape(scatter_version));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferSelectAndScatterShape(
+ operand->output_shape(), *select_program_shape,
+ select_and_scatter_request.window(), source->output_shape(),
+ init_value->output_shape(), *scatter_program_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(select_version);
+ request.add_embedded_computation_versions(scatter_version);
+ *request.mutable_request()->mutable_select_and_scatter_request() =
+ select_and_scatter_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddReverseInstruction(
+ const ReverseRequest& reverse_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(reverse_request.operand()));
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferReverseShape(
+ operand->output_shape(), AsInt64Slice(reverse_request.dimensions())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ *request.mutable_request()->mutable_reverse_request() = reverse_request;
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddWhileInstruction(
+ const WhileRequest& while_request,
+ const UserComputation& condition_computation,
+ const UserComputation& body_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* init,
+ LookupRequest(while_request.init()));
+
+ VersionedComputationHandle::Version condition_version =
+ condition_computation.version();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> condition_program_shape,
+ condition_computation.ComputeProgramShape(condition_version));
+
+ VersionedComputationHandle::Version body_version = body_computation.version();
+ TF_ASSIGN_OR_RETURN(std::shared_ptr<const ProgramShape> body_program_shape,
+ body_computation.ComputeProgramShape(body_version));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferWhileShape(
+ *condition_program_shape, *body_program_shape, init->output_shape()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(condition_version);
+ request.add_embedded_computation_versions(body_version);
+ *request.mutable_request()->mutable_while_request() = while_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddBroadcastInstruction(
+ const BroadcastRequest& broadcast_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Fetches and validates the operand.
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(broadcast_request.operand()));
+ TF_ASSIGN_OR_RETURN(Shape inferred_shape,
+ ShapeInference::InferBroadcastShape(
+ operand->output_shape(),
+ AsInt64Slice(broadcast_request.broadcast_sizes())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ *request.mutable_request()->mutable_broadcast_request() = broadcast_request;
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddReshapeInstruction(
+ const ReshapeRequest& reshape_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Fetches and validates the operand.
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(reshape_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferReshapeShape(
+ operand->output_shape(), AsInt64Slice(reshape_request.dimensions()),
+ AsInt64Slice(reshape_request.new_sizes())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ *request.mutable_request()->mutable_reshape_request() = reshape_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddSliceInstruction(
+ const SliceRequest& slice_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(slice_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape new_shape,
+ ShapeInference::InferSliceShape(
+ operand->output_shape(), AsInt64Slice(slice_request.start_indices()),
+ AsInt64Slice(slice_request.limit_indices())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_slice_request() = slice_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddDynamicSliceInstruction(
+ const DynamicSliceRequest& dynamic_slice_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(dynamic_slice_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* start_indices,
+ LookupRequest(dynamic_slice_request.start_indices()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape new_shape,
+ ShapeInference::InferDynamicSliceShape(
+ operand->output_shape(), start_indices->output_shape(),
+ AsInt64Slice(dynamic_slice_request.slice_sizes())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_dynamic_slice_request() =
+ dynamic_slice_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle>
+UserComputation::AddDynamicUpdateSliceInstruction(
+ const DynamicUpdateSliceRequest& dynamic_update_slice_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(dynamic_update_slice_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* update,
+ LookupRequest(dynamic_update_slice_request.update()));
+
+ TF_ASSIGN_OR_RETURN(
+ const OperationRequest* start_indices,
+ LookupRequest(dynamic_update_slice_request.start_indices()));
+
+ TF_ASSIGN_OR_RETURN(Shape new_shape,
+ ShapeInference::InferDynamicUpdateSliceShape(
+ operand->output_shape(), update->output_shape(),
+ start_indices->output_shape()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_dynamic_update_slice_request() =
+ dynamic_update_slice_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddConcatenateInstruction(
+ const ConcatenateRequest& concatenate_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ std::vector<const Shape*> operand_shapes;
+ for (const ComputationDataHandle& handle : concatenate_request.operands()) {
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+ operand_shapes.push_back(&operand->output_shape());
+ }
+
+ TF_ASSIGN_OR_RETURN(Shape new_shape,
+ ShapeInference::InferConcatOpShape(
+ operand_shapes, concatenate_request.dimension()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_concatenate_request() =
+ concatenate_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddConvertInstruction(
+ const ConvertRequest& convert_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(convert_request.operand()));
+
+ TF_ASSIGN_OR_RETURN(Shape new_shape, ShapeInference::InferConvertShape(
+ operand->output_shape(),
+ convert_request.new_element_type()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = new_shape;
+ *request.mutable_request()->mutable_convert_request() = convert_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddConvolveInstruction(
+ const ConvolveRequest& convolve_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
+ LookupRequest(convolve_request.lhs()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
+ LookupRequest(convolve_request.rhs()));
+ TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvolveShape(
+ lhs->output_shape(), rhs->output_shape(),
+ convolve_request.window(),
+ convolve_request.dimension_numbers()));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_convolve_request() = convolve_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddCrossReplicaSumInstruction(
+ const CrossReplicaSumRequest& cross_replica_sum_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(cross_replica_sum_request.operand()));
+ TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape(
+ operand->output_shape()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_cross_replica_sum_request() =
+ cross_replica_sum_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddInfeedInstruction(
+ const InfeedRequest& infeed_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ const Shape& shape = infeed_request.shape();
+ if (ShapeUtil::IsNestedTuple(shape)) {
+ return InvalidArgument("Infeed does not support nested tuple shapes");
+ }
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument("Shape given to Infeed must have a layout");
+ }
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_infeed_request() = infeed_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddCallInstruction(
+ const CallRequest& call_request,
+ const UserComputation& to_apply_computation) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ std::vector<const Shape*> operand_shapes;
+ for (const ComputationDataHandle& handle : call_request.operands()) {
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+ operand_shapes.push_back(&operand->output_shape());
+ }
+
+ VersionedComputationHandle::Version to_apply_version =
+ to_apply_computation.version();
+ TF_ASSIGN_OR_RETURN(
+ std::shared_ptr<const ProgramShape> to_apply_program_shape,
+ to_apply_computation.ComputeProgramShape(to_apply_version));
+ TF_ASSIGN_OR_RETURN(
+ Shape inferred_shape,
+ ShapeInference::InferCallShape(operand_shapes, *to_apply_program_shape));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = inferred_shape;
+ request.add_embedded_computation_versions(to_apply_version);
+ *request.mutable_request()->mutable_call_request() = call_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
+ const CustomCallRequest& custom_call_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ for (const ComputationDataHandle& handle : custom_call_request.operands()) {
+ TF_RETURN_IF_ERROR(LookupRequest(handle).status());
+ }
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = custom_call_request.shape();
+ *request.mutable_request()->mutable_custom_call_request() =
+ custom_call_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction(
+ const UnaryOpRequest& unary_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand,
+ LookupRequest(unary_request.operand()));
+ TF_ASSIGN_OR_RETURN(
+ Shape shape, ShapeInference::InferUnaryOpShape(unary_request.unop(),
+ operand->output_shape()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_unary_op_request() = unary_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddBinaryInstruction(
+ const BinaryOpRequest& binary_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
+ LookupRequest(binary_request.lhs()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
+ LookupRequest(binary_request.rhs()));
+ TF_ASSIGN_OR_RETURN(
+ Shape shape,
+ ShapeInference::InferBinaryOpShape(
+ binary_request.binop(), lhs->output_shape(), rhs->output_shape(),
+ AsInt64Slice(binary_request.broadcast_dimensions())));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_binary_op_request() = binary_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddTernaryInstruction(
+ const TernaryOpRequest& ternary_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
+ LookupRequest(ternary_request.lhs()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
+ LookupRequest(ternary_request.rhs()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* ehs,
+ LookupRequest(ternary_request.ehs()));
+ TF_ASSIGN_OR_RETURN(Shape shape,
+ ShapeInference::InferTernaryOpShape(
+ ternary_request.triop(), lhs->output_shape(),
+ rhs->output_shape(), ehs->output_shape()));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_ternary_op_request() = ternary_request;
+
+ return handle;
+}
+
+StatusOr<ComputationDataHandle> UserComputation::AddVariadicInstruction(
+ const VariadicOpRequest& variadic_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ std::vector<const Shape*> operand_shapes;
+ for (const ComputationDataHandle& handle : variadic_request.operands()) {
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+ operand_shapes.push_back(&operand->output_shape());
+ }
+
+ TF_ASSIGN_OR_RETURN(Shape shape,
+ ShapeInference::InferVariadicOpShape(
+ variadic_request.varop(), operand_shapes));
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_variadic_op_request() = variadic_request;
+
+ return handle;
+}
+
+StatusOr<Shape> UserComputation::GetShape(const ComputationDataHandle& handle) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookupRequest(handle));
+ return operand->output_shape();
+}
+
+Status UserComputation::SetReturnValue(const ComputationDataHandle& handle) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ if (!(handle.handle() > 0 && handle.handle() < next_handle_value_)) {
+ return InvalidArgument("Invalid handle in SetReturnValue");
+ }
+
+ handle_to_return_ = handle;
+
+ return Status::OK();
+}
+
+VersionedComputationHandle UserComputation::GetVersionedHandle() const {
+ tensorflow::mutex_lock lock(mutex_);
+
+ VersionedComputationHandle versioned_handle;
+ versioned_handle.handle = session_computation_.computation_handle();
+
+ if (handle_to_return_.handle() > 0) {
+ // A specific handle has been requested for the result of the computation.
+ versioned_handle.version = handle_to_return_.handle();
+ } else {
+ // A version value is simply the most recently assigned
+ // ComputationDataHandle value, i.e., the handle value of the root of the
+ // computation.
+ versioned_handle.version = next_handle_value_ - 1;
+ }
+
+ return versioned_handle;
+}
+
+VersionedComputationHandle UserComputation::GetVersionedHandleAtOperation(
+ const ComputationDataHandle& operation) const {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // The version at which an operation was added is simply the handle value of
+ // the ComputationDataHandle.
+ VersionedComputationHandle versioned_handle;
+ versioned_handle.handle = session_computation_.computation_handle();
+ versioned_handle.version = operation.handle();
+ return versioned_handle;
+}
+
+VersionedComputationHandle::Version UserComputation::version() const {
+ return GetVersionedHandle().version;
+}
+
+StatusOr<std::shared_ptr<const ProgramShape>>
+UserComputation::ComputeProgramShape(
+ VersionedComputationHandle::Version version) const {
+ tensorflow::mutex_lock lock(mutex_);
+
+ CHECK(version > 0 && version < next_handle_value_);
+
+ if (program_shape_ == nullptr || program_shape_version_ != version) {
+ // ProgramShape has not been computed yet, or was computed for a different
+ // version. Compute it now.
+ TF_RETURN_IF_ERROR(CheckParametersAreContiguous(version));
+
+ auto program_shape = MakeUnique<ProgramShape>();
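+ // Handles are assigned consecutively starting at 1, so iterating over
+ // 1..version visits every request present in this version.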
+ for (int64 request_num = 1; request_num <= version; ++request_num) {
+ const OperationRequest& request =
+ session_computation_.requests().at(request_num);
+ if (request.request().op_case() == OpRequest::kParameterRequest) {
+ const ParameterRequest& parameter_request =
+ request.request().parameter_request();
+ int64 param_no = parameter_request.parameter();
+ // Parameters may be out of order so expand ProgramShape parameters until
+ // it is at least large enough to hold the current parameter number.
+ while (program_shape->parameters_size() <= param_no) {
+ program_shape->add_parameters();
+ program_shape->add_parameter_names();
+ }
+ *program_shape->mutable_parameters(param_no) = request.output_shape();
+ *program_shape->mutable_parameter_names(param_no) =
+ parameter_request.name();
+ }
+ }
+
+ // The root determines the output shape.
+ *program_shape->mutable_result() = GetRoot(version).output_shape();
+ if (ShapeUtil::IsOpaque(program_shape->result())) {
+ return Unimplemented("Computation results cannot be opaque");
+ }
+
+ program_shape_ = std::move(program_shape);
+ program_shape_version_ = version;
+ }
+
+ return program_shape_;
+}
+
+namespace {
+
+// A visitor which checks whether an operation is a compile-time constant, that
+// is, whether it depends only on values known at compile time rather than on
+// parameters or other run-time inputs. The visitor walks the computation
+// starting at a given operation and sets is_constant to false if an operation
+// that cannot be constant-folded (such as a parameter, RNG, or infeed) is
+// encountered.
+void ConstantVisitor(const SessionComputation& session_computation,
+ const ComputationDataHandle& handle,
+ std::set<int64>* visited, bool* is_constant) {
+ if (visited->count(handle.handle()) != 0 || !*is_constant) {
+ return;
+ }
+
+ const OperationRequest& request =
+ session_computation.requests().at(handle.handle());
+ switch (request.request().op_case()) {
+ case OpRequest::kRngRequest:
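+ // The result of an RNG op varies between executions, so it is never a
+ // compile-time constant.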
+ *is_constant = false;
+ break;
+
+ case OpRequest::kConstantRequest:
+ break;
+
+ case OpRequest::kGetTupleElementRequest: {
+ const GetTupleElementRequest& get_tuple_element_request =
+ request.request().get_tuple_element_request();
+ ConstantVisitor(session_computation, get_tuple_element_request.operand(),
+ visited, is_constant);
+ break;
+ }
+
+ case OpRequest::kSliceRequest: {
+ const SliceRequest& slice_request = request.request().slice_request();
+ ConstantVisitor(session_computation, slice_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kDynamicSliceRequest: {
+ const DynamicSliceRequest& dynamic_slice_request =
+ request.request().dynamic_slice_request();
+ ConstantVisitor(session_computation, dynamic_slice_request.operand(),
+ visited, is_constant);
+ ConstantVisitor(session_computation,
+ dynamic_slice_request.start_indices(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kDynamicUpdateSliceRequest: {
+ const DynamicUpdateSliceRequest& dynamic_update_slice_request =
+ request.request().dynamic_update_slice_request();
+ ConstantVisitor(session_computation,
+ dynamic_update_slice_request.operand(), visited,
+ is_constant);
+ ConstantVisitor(session_computation,
+ dynamic_update_slice_request.update(), visited,
+ is_constant);
+ ConstantVisitor(session_computation,
+ dynamic_update_slice_request.start_indices(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kConcatenateRequest: {
+ const ConcatenateRequest& concatenate_request =
+ request.request().concatenate_request();
+ for (const ComputationDataHandle& handle :
+ concatenate_request.operands()) {
+ ConstantVisitor(session_computation, handle, visited, is_constant);
+ }
+ break;
+ }
+
+ case OpRequest::kConvolveRequest: {
+ const ConvolveRequest& convolve_request =
+ request.request().convolve_request();
+ ConstantVisitor(session_computation, convolve_request.lhs(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, convolve_request.rhs(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kCrossReplicaSumRequest: {
+ // TODO(b/33009255): Implement constant folding for cross replica sum.
+ *is_constant = false;
+ break;
+ }
+
+ case OpRequest::kInfeedRequest: {
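+ // Infeed reads data at execution time, so its result cannot be a
+ // compile-time constant.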
+ *is_constant = false;
+ break;
+ }
+
+ case OpRequest::kCallRequest: {
+ const CallRequest& call_request = request.request().call_request();
+ for (const ComputationDataHandle& handle : call_request.operands()) {
+ ConstantVisitor(session_computation, handle, visited, is_constant);
+ }
+ // TODO(b/32495713): We aren't checking the to_apply computation itself,
+ // so we conservatively say that computations containing the Call op
+ // cannot be constant. We cannot set is_constant=false in other similar
+ // cases since we're already relying on IsConstant to return true.
+ *is_constant = false;
+ break;
+ }
+
+ case OpRequest::kCustomCallRequest: {
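+ // Custom calls are opaque to the compiler, so conservatively treat them
+ // as non-constant.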
+ *is_constant = false;
+ break;
+ }
+
+ case OpRequest::kMapRequest: {
+ const MapRequest& map_request = request.request().map_request();
+ for (const ComputationDataHandle& handle : map_request.operands()) {
+ ConstantVisitor(session_computation, handle, visited, is_constant);
+ }
+ // TODO(b/32495713): We aren't checking the to_apply computation itself.
+ break;
+ }
+
+ case OpRequest::kReduceRequest: {
+ const ReduceRequest& reduce_request = request.request().reduce_request();
+ ConstantVisitor(session_computation, reduce_request.operand(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, reduce_request.init_value(), visited,
+ is_constant);
+ // TODO(b/32495713): We aren't checking the to_apply computation itself.
+ break;
+ }
+
+ case OpRequest::kReduceWindowRequest: {
+ const ReduceWindowRequest& reduce_window_request =
+ request.request().reduce_window_request();
+ ConstantVisitor(session_computation, reduce_window_request.operand(),
+ visited, is_constant);
+ ConstantVisitor(session_computation, reduce_window_request.init_value(),
+ visited, is_constant);
+ // TODO(b/32495713): We aren't checking the to_apply computation itself.
+ break;
+ }
+
+ case OpRequest::kSelectAndScatterRequest: {
+ const SelectAndScatterRequest& select_and_scatter_request =
+ request.request().select_and_scatter_request();
+ ConstantVisitor(session_computation, select_and_scatter_request.operand(),
+ visited, is_constant);
+ ConstantVisitor(session_computation, select_and_scatter_request.source(),
+ visited, is_constant);
+ ConstantVisitor(session_computation,
+ select_and_scatter_request.init_value(), visited,
+ is_constant);
+ // TODO(b/32495713): We aren't checking the select and scatter
+ // computations themselves.
+ break;
+ }
+
+ case OpRequest::kBroadcastRequest: {
+ const BroadcastRequest& broadcast_request =
+ request.request().broadcast_request();
+ ConstantVisitor(session_computation, broadcast_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kReshapeRequest: {
+ const ReshapeRequest& reshape_request =
+ request.request().reshape_request();
+ ConstantVisitor(session_computation, reshape_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kReverseRequest: {
+ const ReverseRequest& reverse_request =
+ request.request().reverse_request();
+ ConstantVisitor(session_computation, reverse_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kPadRequest: {
+ const PadRequest& pad_request = request.request().pad_request();
+ ConstantVisitor(session_computation, pad_request.operand(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, pad_request.padding_value(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kParameterRequest: {
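+ // A parameter's value is supplied at execution time, so anything that
+ // depends on it is not a compile-time constant.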
+ *is_constant = false;
+ break;
+ }
+
+ case OpRequest::kConvertRequest: {
+ const ConvertRequest& convert_request =
+ request.request().convert_request();
+ ConstantVisitor(session_computation, convert_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kWhileRequest: {
+ const WhileRequest& while_request = request.request().while_request();
+ ConstantVisitor(session_computation, while_request.init(), visited,
+ is_constant);
+ // TODO(b/32495713): We aren't checking the condition and body
+ // computations themselves.
+ break;
+ }
+
+ case OpRequest::kTernaryOpRequest: {
+ const TernaryOpRequest& ternary_op_request =
+ request.request().ternary_op_request();
+ ConstantVisitor(session_computation, ternary_op_request.lhs(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, ternary_op_request.rhs(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, ternary_op_request.ehs(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kVariadicOpRequest: {
+ const VariadicOpRequest& variadic_op_request =
+ request.request().variadic_op_request();
+ for (const ComputationDataHandle& handle :
+ variadic_op_request.operands()) {
+ ConstantVisitor(session_computation, handle, visited, is_constant);
+ }
+ break;
+ }
+
+ case OpRequest::kUnaryOpRequest: {
+ const UnaryOpRequest& unary_op_request =
+ request.request().unary_op_request();
+ ConstantVisitor(session_computation, unary_op_request.operand(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::kBinaryOpRequest: {
+ const BinaryOpRequest& binary_op_request =
+ request.request().binary_op_request();
+ ConstantVisitor(session_computation, binary_op_request.lhs(), visited,
+ is_constant);
+ ConstantVisitor(session_computation, binary_op_request.rhs(), visited,
+ is_constant);
+ break;
+ }
+
+ case OpRequest::OP_NOT_SET:
+ LOG(FATAL) << "OperationRequest doesn't contain a request";
+
+ default:
+ LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
+ }
+ visited->insert(handle.handle());
+}
+
+} // namespace
+
+StatusOr<bool> UserComputation::IsConstant(
+ const ComputationDataHandle& handle) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ // Verify that the handle is valid.
+ auto operation_status = LookupRequest(handle);
+ if (!operation_status.ok()) {
+ return operation_status.status();
+ }
+
+ bool is_constant = true;
+ std::set<int64> visited;
+ ConstantVisitor(session_computation_, handle, &visited, &is_constant);
+
+ return is_constant;
+}
+
+const OperationRequest& UserComputation::GetRoot(
+ VersionedComputationHandle::Version version) const {
+ CHECK(version > 0 && version < next_handle_value_);
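+ // A version is the handle value of the computation's root at the time the
+ // version was generated, so the root request is requests().at(version).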
+ return session_computation_.requests().at(version);
+}
+
+std::vector<VersionedComputationHandle>
+UserComputation::GetEmbeddedComputations(
+ VersionedComputationHandle::Version version) const {
+ tensorflow::mutex_lock lock(mutex_);
+
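+ // Collect the versioned handles of every computation referenced (via Call,
+ // Map, Reduce, ReduceWindow, SelectAndScatter, or While) by requests at or
+ // below the given version.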
+ std::vector<VersionedComputationHandle> computations;
+ for (const auto& handle_request : session_computation_.requests()) {
+ int64 handle_value = handle_request.first;
+ if (handle_value <= version) {
+ const OperationRequest& request = handle_request.second;
+ switch (request.request().op_case()) {
+ case OpRequest::kCallRequest: {
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ const CallRequest& call_request = request.request().call_request();
+ const VersionedComputationHandle versioned_handle = {
+ call_request.to_apply(),
+ request.embedded_computation_versions(0)};
+ computations.push_back(versioned_handle);
+ break;
+ }
+
+ case OpRequest::kMapRequest: {
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ const MapRequest& map_request = request.request().map_request();
+ const VersionedComputationHandle versioned_handle = {
+ map_request.to_apply(), request.embedded_computation_versions(0)};
+ computations.push_back(versioned_handle);
+ break;
+ }
+
+ case OpRequest::kReduceRequest: {
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ const ReduceRequest& reduce_request =
+ request.request().reduce_request();
+ const VersionedComputationHandle versioned_handle = {
+ reduce_request.to_apply(),
+ request.embedded_computation_versions(0)};
+ computations.push_back(versioned_handle);
+ break;
+ }
+
+ case OpRequest::kReduceWindowRequest: {
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ const ReduceWindowRequest& reduce_window_request =
+ request.request().reduce_window_request();
+ const VersionedComputationHandle versioned_handle = {
+ reduce_window_request.to_apply(),
+ request.embedded_computation_versions(0)};
+ computations.push_back(versioned_handle);
+ break;
+ }
+
+ case OpRequest::kSelectAndScatterRequest: {
+ CHECK_EQ(2, request.embedded_computation_versions_size());
+ const SelectAndScatterRequest& select_and_scatter_request =
+ request.request().select_and_scatter_request();
+ const VersionedComputationHandle select_versioned_handle = {
+ select_and_scatter_request.select(),
+ request.embedded_computation_versions(0)};
+ computations.push_back(select_versioned_handle);
+ const VersionedComputationHandle scatter_versioned_handle = {
+ select_and_scatter_request.scatter(),
+ request.embedded_computation_versions(1)};
+ computations.push_back(scatter_versioned_handle);
+ break;
+ }
+
+ case OpRequest::kWhileRequest: {
+ CHECK_EQ(2, request.embedded_computation_versions_size());
+ const WhileRequest& while_request = request.request().while_request();
+ const VersionedComputationHandle condition_versioned_handle = {
+ while_request.condition(),
+ request.embedded_computation_versions(0)};
+ computations.push_back(condition_versioned_handle);
+ const VersionedComputationHandle body_versioned_handle = {
+ while_request.body(), request.embedded_computation_versions(1)};
+ computations.push_back(body_versioned_handle);
+ break;
+ }
+
+ default:
+ // No embedded computation.
+ break;
+ }
+ }
+ }
+ return computations;
+}
+
+Status UserComputation::RemapEmbeddedComputations(
+ const std::map<int64, ComputationHandle>& old_to_new) {
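+ // Rewrites a single embedded ComputationHandle in place using the
+ // old-to-new mapping; returns NotFound if the handle is not in the map.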
+ auto update = [&old_to_new](ComputationHandle* to_update) -> Status {
+ int64 old = to_update->handle();
+ auto it = old_to_new.find(old);
+ if (it == old_to_new.end()) {
+ string mapping = tensorflow::str_util::Join(
+ old_to_new, ", ",
+ [](string* out, std::pair<int64, ComputationHandle> element) {
+ tensorflow::strings::Appendf(out, "%lld:%lld", element.first,
+ element.second.handle());
+ });
+ return NotFound(
+ "could not find referenced (old) computation handle in mapping: "
+ "%lld; mapping: {%s}",
+ old, mapping.c_str());
+ }
+ VLOG(2) << "remapping " << old << " to " << it->second.handle();
+ *to_update = it->second;
+ return Status::OK();
+ };
+ TF_RETURN_IF_ERROR(update(session_computation_.mutable_computation_handle()));
+ for (auto& handle_request : *session_computation_.mutable_requests()) {
+ OperationRequest& request = handle_request.second;
+ switch (request.request().op_case()) {
+ case OpRequest::kCallRequest: {
+ TF_RET_CHECK(1 == request.embedded_computation_versions_size());
+ CallRequest* call_request =
+ request.mutable_request()->mutable_call_request();
+ TF_RETURN_IF_ERROR(update(call_request->mutable_to_apply()));
+ break;
+ }
+ case OpRequest::kMapRequest: {
+ TF_RET_CHECK(1 == request.embedded_computation_versions_size());
+ MapRequest* map_request =
+ request.mutable_request()->mutable_map_request();
+ TF_RETURN_IF_ERROR(update(map_request->mutable_to_apply()));
+ break;
+ }
+ case OpRequest::kReduceRequest: {
+ TF_RET_CHECK(1 == request.embedded_computation_versions_size());
+ ReduceRequest* reduce_request =
+ request.mutable_request()->mutable_reduce_request();
+ TF_RETURN_IF_ERROR(update(reduce_request->mutable_to_apply()));
+ break;
+ }
+ case OpRequest::kReduceWindowRequest: {
+ TF_RET_CHECK(1 == request.embedded_computation_versions_size());
+ ReduceWindowRequest* reduce_window_request =
+ request.mutable_request()->mutable_reduce_window_request();
+ TF_RETURN_IF_ERROR(update(reduce_window_request->mutable_to_apply()));
+ break;
+ }
+ case OpRequest::kSelectAndScatterRequest: {
+ TF_RET_CHECK(2 == request.embedded_computation_versions_size());
+ SelectAndScatterRequest* select_and_scatter_request =
+ request.mutable_request()->mutable_select_and_scatter_request();
+ TF_RETURN_IF_ERROR(
+ update(select_and_scatter_request->mutable_select()));
+ TF_RETURN_IF_ERROR(
+ update(select_and_scatter_request->mutable_scatter()));
+ break;
+ }
+ case OpRequest::kWhileRequest: {
+ TF_RET_CHECK(2 == request.embedded_computation_versions_size());
+ WhileRequest* while_request =
+ request.mutable_request()->mutable_while_request();
+ TF_RETURN_IF_ERROR(update(while_request->mutable_condition()));
+ TF_RETURN_IF_ERROR(update(while_request->mutable_body()));
+ break;
+ }
+ default:
+ // No embedded computation.
+ TF_RET_CHECK(0 == request.embedded_computation_versions_size());
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+SessionComputation UserComputation::CloneSessionComputation(
+ VersionedComputationHandle::Version version) const {
+ tensorflow::mutex_lock lock(mutex_);
+ SessionComputation result = session_computation_;
+ // Erase all the requests that exceed the version specified.
+ // There's no lower_bound method on tensorflow::protobuf::Map, so we iterate
+ // over all the elements.
+ auto it = result.mutable_requests()->begin();
+ while (it != result.mutable_requests()->end()) {
+ if (it->first > version) {
+ it = result.mutable_requests()->erase(it);
+ } else {
+ ++it;
+ }
+ }
+ return result;
+}
+
+StatusOr<const OperationRequest*> UserComputation::LookupRequest(
+ const ComputationDataHandle& handle) const {
+ int64 handle_value = handle.handle();
+ if (session_computation_.requests().count(handle_value) == 0) {
+ return InvalidArgument("no ComputationDataHandle value %lld", handle_value);
+ }
+ return &session_computation_.requests().at(handle_value);
+}
+
+Status UserComputation::CheckParametersAreContiguous(
+ VersionedComputationHandle::Version version) const {
+ TF_RET_CHECK(version > 0 && version < next_handle_value_);
+
+ // Determine number of parameter inputs at the given version.
+ std::map<int64, const ParameterRequest*> parameter_requests;
+ for (int64 request_num = 1; request_num <= version; ++request_num) {
+ const OperationRequest& request =
+ session_computation_.requests().at(request_num);
+
+ if (request.request().op_case() == OpRequest::kParameterRequest) {
+ const ParameterRequest& parameter_request =
+ request.request().parameter_request();
+ // Duplicate parameters should be checked when parameter requests are
+ // added.
+ TF_RET_CHECK(0 ==
+ parameter_requests.count(parameter_request.parameter()));
+ parameter_requests[parameter_request.parameter()] = &parameter_request;
+ }
+ }
+
+ auto program_shape = MakeUnique<ProgramShape>();
+ for (int64 i = 0; i < parameter_requests.size(); ++i) {
+ auto it = parameter_requests.find(i);
+ if (it == parameter_requests.end()) {
+ return FailedPrecondition(
+ "computation %s does not have all its parameters populated "
+ "sequentially, missing parameter %lld",
+ name_.c_str(), i);
+ }
+ }
+
+ return Status::OK();
+}
+
+namespace {
+
+// Helper class which builds an HLO computation from a SessionComputation. To
+// construct the HLO computation, the SessionComputation graph is walked in
+// DFS order lowering each OperationRequest to an HLO instruction.
+class ComputationLowerer {
+ public:
+ static std::unique_ptr<HloComputation> Lower(
+ const string& computation_name,
+ const SessionComputation& session_computation,
+ VersionedComputationHandle::Version version,
+ UserComputation::HloComputationResolver hlo_resolver,
+ bool include_unused_parameters) {
+ ComputationLowerer lowerer(computation_name, session_computation, version,
+ std::move(hlo_resolver));
+ return lowerer.Lower(include_unused_parameters);
+ }
+
+ private:
+ ComputationLowerer(const string& computation_name,
+ const SessionComputation& session_computation,
+ VersionedComputationHandle::Version version,
+ UserComputation::HloComputationResolver hlo_resolver)
+ : hlo_builder_(computation_name),
+ session_computation_(session_computation),
+ version_(version),
+ hlo_resolver_(std::move(hlo_resolver)) {}
+
+ // Build an HLO computation from the SessionComputation at the given
+ // version.
+ std::unique_ptr<HloComputation> Lower(bool include_unused_parameters);
+
+ private:
+ // DFS visitor of the UserComputation operations which lowers the operations
+ // to HLO instructions.
+ HloInstruction* Visit(const ComputationDataHandle& handle,
+ std::map<int64, HloInstruction*>* visited);
+
+ // Resolves a ComputationHandle and Version to a previously lowered
+ // HloComputation using the hlo_resolver_ function.
+ HloComputation* ResolveComputation(
+ const ComputationHandle& handle,
+ VersionedComputationHandle::Version version);
+
+ HloComputation::Builder hlo_builder_;
+ const SessionComputation& session_computation_;
+ const VersionedComputationHandle::Version version_;
+ const UserComputation::HloComputationResolver hlo_resolver_;
+};
+
+std::unique_ptr<HloComputation> ComputationLowerer::Lower(
+ bool include_unused_parameters) {
+ // Map from ComputationDataHandle to HLO instruction. Serves as a record of
+ // which operations have been visited as well as a cache for looking up
+ // ComputationDataHandles as HloInstructions.
+ std::map<int64, HloInstruction*> visited;
+
+ // A version is simply a ComputationDataHandle of the root of the computation
+ // at the time the version was generated. Create a ComputationDataHandle with
+ // this value and pass it to the visitor as the root of the computation to
+ // lower.
+ ComputationDataHandle root_handle;
+ root_handle.set_handle(version_);
+
+ HloInstruction* hlo_root = Visit(root_handle, &visited);
+
+ // A computation may have unused parameters.
+ if (include_unused_parameters) {
+ for (int64 request_num = 1; request_num <= version_; ++request_num) {
+ const OperationRequest& request =
+ session_computation_.requests().at(request_num);
+ if (request.request().op_case() == OpRequest::kParameterRequest &&
+ visited.count(request.output_handle().handle()) == 0) {
+ Visit(request.output_handle(), &visited);
+ }
+ }
+ }
+
+ // Add trace instructions.
+ for (const auto& trace_request : session_computation_.trace_requests()) {
+ if (trace_request.operand().handle() <= version_) {
+ HloInstruction* operand = visited[trace_request.operand().handle()];
+ // Trace instructions cannot be the root of a computation.
+ HloInstruction* trace_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateTrace(trace_request.tag(), operand));
+ operand->set_tracing(trace_instruction);
+ }
+ }
+
+ // Send instructions do not have users, so they are not reachable from the
+ // root instruction. Therefore, explicitly visit all Send requests (and their
+ // operand chains) and add to the builder.
+ for (const auto& send_request : session_computation_.send_requests()) {
+ Visit(send_request.operand(), &visited);
+ HloInstruction* operand = visited[send_request.operand().handle()];
+ HloInstruction* send_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateSend(operand));
+ send_instruction->set_channel_id(send_request.channel_handle().handle());
+ }
+
+ return hlo_builder_.Build(hlo_root);
+}
+
+HloComputation* ComputationLowerer::ResolveComputation(
+ const ComputationHandle& handle,
+ VersionedComputationHandle::Version version) {
+ const VersionedComputationHandle checked_handle = {handle, version};
+ return hlo_resolver_(checked_handle);
+}
+
+HloInstruction* ComputationLowerer::Visit(
+ const ComputationDataHandle& handle,
+ std::map<int64, HloInstruction*>* visited) {
+ if (visited->count(handle.handle()) != 0) {
+ return (*visited)[handle.handle()];
+ }
+
+ const OperationRequest& request =
+ session_computation_.requests().at(handle.handle());
+ HloInstruction* hlo_instruction;
+ switch (request.request().op_case()) {
+ case OpRequest::kRngRequest: {
+ const RngRequest& rng_request = request.request().rng_request();
+ std::vector<HloInstruction*> parameters;
+ for (const ComputationDataHandle& param : rng_request.parameter()) {
+ parameters.push_back(Visit(param, visited));
+ }
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateRng(
+ request.output_shape(), rng_request.distribution(), parameters));
+ break;
+ }
+
+ case OpRequest::kConstantRequest: {
+ const ConstantRequest& constant_request =
+ request.request().constant_request();
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CloneToUnique(constant_request.literal())));
+ break;
+ }
+
+ case OpRequest::kGetTupleElementRequest: {
+ const GetTupleElementRequest& get_tuple_element_request =
+ request.request().get_tuple_element_request();
+ HloInstruction* operand =
+ Visit(get_tuple_element_request.operand(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateGetTupleElement(
+ request.output_shape(), operand,
+ get_tuple_element_request.index()));
+ break;
+ }
+
+ case OpRequest::kSliceRequest: {
+ const SliceRequest& slice_request = request.request().slice_request();
+ HloInstruction* operand = Visit(slice_request.operand(), visited);
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateSlice(
+ request.output_shape(), operand,
+ AsInt64Slice(slice_request.start_indices()),
+ AsInt64Slice(slice_request.limit_indices())));
+ break;
+ }
+
+ case OpRequest::kDynamicSliceRequest: {
+ const DynamicSliceRequest& dynamic_slice_request =
+ request.request().dynamic_slice_request();
+ HloInstruction* operand = Visit(dynamic_slice_request.operand(), visited);
+ HloInstruction* start_indices =
+ Visit(dynamic_slice_request.start_indices(), visited);
+
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateDynamicSlice(
+ request.output_shape(), operand, start_indices,
+ AsInt64Slice(dynamic_slice_request.slice_sizes())));
+ break;
+ }
+
+ case OpRequest::kDynamicUpdateSliceRequest: {
+ const DynamicUpdateSliceRequest& dynamic_update_slice_request =
+ request.request().dynamic_update_slice_request();
+ HloInstruction* operand =
+ Visit(dynamic_update_slice_request.operand(), visited);
+ HloInstruction* update =
+ Visit(dynamic_update_slice_request.update(), visited);
+ HloInstruction* start_indices =
+ Visit(dynamic_update_slice_request.start_indices(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+ request.output_shape(), operand, update, start_indices));
+ break;
+ }
+
+ case OpRequest::kConcatenateRequest: {
+ const ConcatenateRequest& concatenate_request =
+ request.request().concatenate_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& handle :
+ concatenate_request.operands()) {
+ HloInstruction* operand = Visit(handle, visited);
+ operands.push_back(operand);
+ }
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateConcatenate(request.output_shape(), operands,
+ concatenate_request.dimension()));
+ break;
+ }
+
+ case OpRequest::kConvolveRequest: {
+ const ConvolveRequest& convolve_request =
+ request.request().convolve_request();
+ HloInstruction* lhs = Visit(convolve_request.lhs(), visited);
+ HloInstruction* rhs = Visit(convolve_request.rhs(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateConvolve(
+ request.output_shape(), lhs, rhs, convolve_request.window(),
+ convolve_request.dimension_numbers()));
+ break;
+ }
+
+ case OpRequest::kCrossReplicaSumRequest: {
+ const CrossReplicaSumRequest& cross_replica_sum_request =
+ request.request().cross_replica_sum_request();
+ HloInstruction* operand =
+ Visit(cross_replica_sum_request.operand(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ request.output_shape(), operand));
+ break;
+ }
+
+ case OpRequest::kInfeedRequest: {
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateInfeed(request.output_shape()));
+ break;
+ }
+
+ case OpRequest::kMapRequest: {
+ const MapRequest& map_request = request.request().map_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& handle : map_request.operands()) {
+ HloInstruction* operand = Visit(handle, visited);
+ operands.push_back(operand);
+ }
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version map_version =
+ request.embedded_computation_versions(0);
+ HloComputation* map_computation =
+ ResolveComputation(map_request.to_apply(), map_version);
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateMap(
+ request.output_shape(), operands, map_computation));
+ break;
+ }
+
+ case OpRequest::kReduceRequest: {
+ const ReduceRequest& reduce_request = request.request().reduce_request();
+ HloInstruction* operand = Visit(reduce_request.operand(), visited);
+ HloInstruction* init_value = Visit(reduce_request.init_value(), visited);
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version reduce_version =
+ request.embedded_computation_versions(0);
+ HloComputation* reduce_computation =
+ ResolveComputation(reduce_request.to_apply(), reduce_version);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateReduce(
+ request.output_shape(), operand, init_value,
+ AsInt64Slice(reduce_request.dimensions()), reduce_computation));
+ break;
+ }
+
+ case OpRequest::kReduceWindowRequest: {
+ const ReduceWindowRequest& reduce_window_request =
+ request.request().reduce_window_request();
+ HloInstruction* operand = Visit(reduce_window_request.operand(), visited);
+ HloInstruction* init_value =
+ Visit(reduce_window_request.init_value(), visited);
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version reduce_window_version =
+ request.embedded_computation_versions(0);
+ HloComputation* reduce_window_computation = ResolveComputation(
+ reduce_window_request.to_apply(), reduce_window_version);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateReduceWindow(
+ request.output_shape(), operand, init_value,
+ reduce_window_request.window(), reduce_window_computation));
+ break;
+ }
+
+ case OpRequest::kSelectAndScatterRequest: {
+ const SelectAndScatterRequest& select_and_scatter_request =
+ request.request().select_and_scatter_request();
+ HloInstruction* operand =
+ Visit(select_and_scatter_request.operand(), visited);
+ HloInstruction* source =
+ Visit(select_and_scatter_request.source(), visited);
+ HloInstruction* init_value =
+ Visit(select_and_scatter_request.init_value(), visited);
+ CHECK_EQ(2, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version select_version =
+ request.embedded_computation_versions(0);
+ VersionedComputationHandle::Version scatter_version =
+ request.embedded_computation_versions(1);
+ HloComputation* select_computation = ResolveComputation(
+ select_and_scatter_request.select(), select_version);
+ HloComputation* scatter_computation = ResolveComputation(
+ select_and_scatter_request.scatter(), scatter_version);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateSelectAndScatter(
+ request.output_shape(), operand, select_computation,
+ select_and_scatter_request.window(), source, init_value,
+ scatter_computation));
+ break;
+ }
+
+ case OpRequest::kBroadcastRequest: {
+ const BroadcastRequest& broadcast_request =
+ request.request().broadcast_request();
+ HloInstruction* operand = Visit(broadcast_request.operand(), visited);
+ std::vector<int64> broadcast_dimensions;
+ // The client-level broadcast instruction just appends dimensions on the
+ // left (adds lowest numbered dimensions). The HLO broadcast op is more
+ // flexible and can add new dimensions anywhere. The broadcast_dimensions
+ // maps operand dimensions to dimensions in the broadcast output, so
+ // to append dimensions on the left the broadcast_dimensions should just
+ // be the n highest dimension numbers of the output shape where n is
+ // the number of input dimensions.
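+ // For example (illustrative): a rank-2 operand broadcast into a rank-4
+ // output yields broadcast_dimensions = {0 + 4 - 2, 1 + 4 - 2} = {2, 3}.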
+ for (int i = 0; i < ShapeUtil::Rank(operand->shape()); ++i) {
+ broadcast_dimensions.push_back(i +
+ ShapeUtil::Rank(request.output_shape()) -
+ ShapeUtil::Rank(operand->shape()));
+ }
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
+ request.output_shape(), operand, broadcast_dimensions));
+ break;
+ }
+
+ case OpRequest::kReshapeRequest: {
+ const ReshapeRequest& reshape_request =
+ request.request().reshape_request();
+ HloInstruction* operand = Visit(reshape_request.operand(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
+ request.output_shape(),
+ hlo_builder_.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::PermuteDimensions(
+ InversePermutation(
+ AsInt64Slice(reshape_request.dimensions())),
+ operand->shape()),
+ operand, AsInt64Slice(reshape_request.dimensions())))));
+ break;
+ }
+
+ case OpRequest::kReverseRequest: {
+ const ReverseRequest& reverse_request =
+ request.request().reverse_request();
+ HloInstruction* operand = Visit(reverse_request.operand(), visited);
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateReverse(
+ request.output_shape(), operand,
+ AsInt64Slice(reverse_request.dimensions())));
+ break;
+ }
+
+ case OpRequest::kPadRequest: {
+ const PadRequest& pad_request = request.request().pad_request();
+ HloInstruction* operand = Visit(pad_request.operand(), visited);
+ HloInstruction* padding_value =
+ Visit(pad_request.padding_value(), visited);
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreatePad(
+ request.output_shape(), operand, padding_value,
+ pad_request.padding_config()));
+ break;
+ }
+
+ case OpRequest::kRecvRequest: {
+ const RecvRequest& recv_request = request.request().recv_request();
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateRecv(request.output_shape()));
+ hlo_instruction->set_channel_id(recv_request.channel_handle().handle());
+ break;
+ }
+
+ case OpRequest::kParameterRequest: {
+ const ParameterRequest& parameter_request =
+ request.request().parameter_request();
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateParameter(
+ parameter_request.parameter(), request.output_shape(),
+ parameter_request.name()));
+ break;
+ }
+
+ case OpRequest::kConvertRequest: {
+ const ConvertRequest& convert_request =
+ request.request().convert_request();
+ HloInstruction* operand = Visit(convert_request.operand(), visited);
+ hlo_instruction = hlo_builder_.AddInstruction(
+ HloInstruction::CreateConvert(request.output_shape(), operand));
+ break;
+ }
+
+ case OpRequest::kWhileRequest: {
+ const WhileRequest& while_request = request.request().while_request();
+ CHECK_EQ(2, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version condition_version =
+ request.embedded_computation_versions(0);
+ HloComputation* condition =
+ ResolveComputation(while_request.condition(), condition_version);
+ VersionedComputationHandle::Version body_version =
+ request.embedded_computation_versions(1);
+ HloComputation* body =
+ ResolveComputation(while_request.body(), body_version);
+ HloInstruction* init = Visit(while_request.init(), visited);
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateWhile(
+ request.output_shape(), condition, body, init));
+ break;
+ }
+
+ case OpRequest::kTernaryOpRequest: {
+ const TernaryOpRequest& ternary_op_request =
+ request.request().ternary_op_request();
+ HloInstruction* lhs = Visit(ternary_op_request.lhs(), visited);
+ HloInstruction* rhs = Visit(ternary_op_request.rhs(), visited);
+ HloInstruction* ehs = Visit(ternary_op_request.ehs(), visited);
+ auto hlo_opcode = TernaryOperationToHloOpcode(ternary_op_request.triop());
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateTernary(
+ request.output_shape(), hlo_opcode, lhs, rhs, ehs));
+ break;
+ }
+
+ case OpRequest::kVariadicOpRequest: {
+ const VariadicOpRequest& variadic_op_request =
+ request.request().variadic_op_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& handle :
+ variadic_op_request.operands()) {
+ HloInstruction* operand = Visit(handle, visited);
+ operands.push_back(operand);
+ }
+ auto hlo_opcode =
+ VariadicOperationToHloOpcode(variadic_op_request.varop());
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateVariadic(
+ request.output_shape(), hlo_opcode, operands));
+ break;
+ }
+
+ case OpRequest::kCallRequest: {
+ const CallRequest& call_request = request.request().call_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& handle : call_request.operands()) {
+ operands.push_back(Visit(handle, visited));
+ }
+ CHECK_EQ(1, request.embedded_computation_versions_size());
+ VersionedComputationHandle::Version call_version =
+ request.embedded_computation_versions(0);
+ HloComputation* call_computation =
+ ResolveComputation(call_request.to_apply(), call_version);
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateCall(
+ request.output_shape(), operands, call_computation));
+ break;
+ }
+
+ case OpRequest::kCustomCallRequest: {
+ const CustomCallRequest& cc_request =
+ request.request().custom_call_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& operand : cc_request.operands()) {
+ operands.push_back(Visit(operand, visited));
+ }
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateCustomCall(
+ cc_request.shape(), operands, cc_request.call_target_name()));
+ break;
+ }
+
+ case OpRequest::kUnaryOpRequest: {
+ const UnaryOpRequest& unary_op_request =
+ request.request().unary_op_request();
+ HloInstruction* operand = Visit(unary_op_request.operand(), visited);
+ auto hlo_opcode = UnaryOperationToHloOpcode(unary_op_request.unop());
+ hlo_instruction = hlo_builder_.AddInstruction(HloInstruction::CreateUnary(
+ request.output_shape(), hlo_opcode, operand));
+ break;
+ }
+
+ case OpRequest::kBinaryOpRequest: {
+ const BinaryOpRequest& binary_op_request =
+ request.request().binary_op_request();
+ HloInstruction* lhs = Visit(binary_op_request.lhs(), visited);
+ HloInstruction* rhs = Visit(binary_op_request.rhs(), visited);
+ auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop());
+ if (binary_op_request.broadcast_dimensions_size() > 0) {
+ // Emit a broadcast instruction to perform the "broadcast in dimension"
+ // operation.
+ CHECK_NE(ShapeUtil::Rank(lhs->shape()), ShapeUtil::Rank(rhs->shape()));
+ HloInstruction* operand_to_broadcast =
+ ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs
+ : rhs;
+ Shape broadcast_shape = ShapeUtil::MakeShape(
+ operand_to_broadcast->shape().element_type(),
+ AsInt64Slice(request.output_shape().dimensions()));
+
+ CHECK_EQ(ShapeUtil::Rank(operand_to_broadcast->shape()),
+ binary_op_request.broadcast_dimensions().size());
+ // The broadcast semantics of a client-level binary op broadcast is
+ // identical to the HLO broadcast semantics so the broadcast_dimensions
+ // field can just be passed to the instruction builder.
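+ // Illustrative example (not from the original comment): adding an f32[5]
+ // vector to an f32[3,5] matrix with broadcast_dimensions = {1} broadcasts
+ // the vector to f32[3,5] before the element-wise add.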
+ HloInstruction* broadcasted_operand =
+ hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
+ broadcast_shape, operand_to_broadcast,
+ AsInt64Slice(binary_op_request.broadcast_dimensions())));
+
+ lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
+ rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
+ }
+ hlo_instruction =
+ hlo_builder_.AddInstruction(HloInstruction::CreateBinary(
+ request.output_shape(), hlo_opcode, lhs, rhs));
+ break;
+ }
+
+ case OpRequest::OP_NOT_SET:
+ LOG(FATAL) << "OperationRequest doesn't contain a request";
+
+ default:
+ LOG(FATAL) << "Unexpected request type: " << request.request().op_case();
+ }
+ (*visited)[handle.handle()] = hlo_instruction;
+ return hlo_instruction;
+}
+
+} // namespace
+
+StatusOr<std::unique_ptr<HloComputation>> UserComputation::BuildHloComputation(
+ VersionedComputationHandle::Version version,
+ HloComputationResolver hlo_resolver, bool include_unused_parameters) const {
+ tensorflow::mutex_lock lock(mutex_);
+
+ VLOG(2) << "Building HloComputation from UserComputation " << name_
+ << " at version " << version << ". Operation requests:\n"
+ << session_computation_.ShortDebugString();
+
+ std::unique_ptr<HloComputation> hlo_computation = ComputationLowerer::Lower(
+ tensorflow::strings::StrCat(name(), ".v", version), session_computation_,
+ version, std::move(hlo_resolver), include_unused_parameters);
+
+ VLOG(2) << "HloComputation:\n" << hlo_computation->ToString();
+ return std::move(hlo_computation);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
new file mode 100644
index 0000000000..06824b01c7
--- /dev/null
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -0,0 +1,336 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A UserComputation is the built-up computation that users create via the
+// XLA Service interface.
+//
+// The XLA service adds instructions to a user computation via this
+// interface. The state of the computation is stored as a SessionComputation
+// proto which holds a record of all operation-building requests received by the
+// XLA service.
+//
+// UserComputations are lowered to HloComputations which are passed to the high
+// level compiler interface.
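+//
+// Illustrative usage sketch (the request protos and the resolver are
+// placeholders, not defined in this header):
+//
+//   UserComputation computation("add", computation_handle);
+//   auto x = computation.AddParameterInstruction(param0_request);
+//   auto y = computation.AddParameterInstruction(param1_request);
+//   auto sum = computation.AddBinaryInstruction(add_request);
+//   auto hlo = computation.BuildHloComputation(computation.version(),
+//                                              hlo_resolver);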
+class UserComputation {
+ public:
+ // Factory used when restoring a computation from serialized session
+ // computation (computation snapshot) data. Remaps any references to
+ // computation handles via the old_to_new mapping.
+ //
+ // An error will occur if the old_to_new mapping cannot resolve a reference to
+ // a computation that is present in session_computation.
+ static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping(
+ const SessionComputation& session_computation,
+ const ComputationHandle& handle,
+ const std::map<int64, ComputationHandle>& old_to_new);
+
+ // Creates an empty computation with the given name and computation handle.
+ explicit UserComputation(const string& name, const ComputationHandle& handle);
+
+ // Enqueues a parameter-retrieving instruction onto this user computation.
+ // Returns an error status if the parameter number is already registered with
+ // different values.
+ StatusOr<ComputationDataHandle> AddParameterInstruction(
+ const ParameterRequest& parameter_request);
+
+ // Enqueues a pad instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddPadInstruction(
+ const PadRequest& pad_request);
+
+ // Enqueues a tracing instruction onto this user computation.
+ // Returns an error status if the operand cannot be resolved.
+ Status AddTraceInstruction(const TraceRequest& trace_request);
+
+ // Enqueues a random number generation instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddRngInstruction(
+ const RngRequest& rng_request);
+
+ // Enqueues a unary instruction onto this user computation.
+ // Returns an error status if the operand index is out of bounds.
+ StatusOr<ComputationDataHandle> AddUnaryInstruction(
+ const UnaryOpRequest& unary_request);
+
+ // Enqueues a binary instruction onto this user computation.
+ // Returns an error status if the operand indices are out of bounds.
+ StatusOr<ComputationDataHandle> AddBinaryInstruction(
+ const BinaryOpRequest& binary_request);
+
+ // Enqueues a ternary instruction onto this user computation.
+ // Returns an error status if the operand indices are out of bounds.
+ StatusOr<ComputationDataHandle> AddTernaryInstruction(
+ const TernaryOpRequest& request);
+
+ // Enqueues a variadic instruction onto this user computation.
+ // Returns an error status if the operand indices are out of bounds.
+ StatusOr<ComputationDataHandle> AddVariadicInstruction(
+ const VariadicOpRequest& variadic_request);
+
+ // Enqueues a constant instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddConstantInstruction(
+ const ConstantRequest& constant_request);
+
+ // Enqueues a get tuple element instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddGetTupleElementInstruction(
+ const GetTupleElementRequest& get_tuple_element_request);
+
+ // Enqueues a map instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddMapInstruction(
+ const MapRequest& map_request,
+ const UserComputation& to_apply_computation);
+
+ // Enqueues a convolution instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddConvolveInstruction(
+ const ConvolveRequest& convolve_request);
+
+ // Enqueues a cross replica sum instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
+ const CrossReplicaSumRequest& cross_replica_sum_request);
+
+ // Enqueues an infeed instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddInfeedInstruction(
+ const InfeedRequest& infeed_request);
+
+ // Enqueues a call instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddCallInstruction(
+ const CallRequest& call_request,
+ const UserComputation& to_apply_computation);
+
+ // Enqueues a custom call instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddCustomCallInstruction(
+ const CustomCallRequest& custom_call_request);
+
+ // Enqueues a broadcast instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddBroadcastInstruction(
+ const BroadcastRequest& broadcast_request);
+
+ // Enqueues a reshape instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddReshapeInstruction(
+ const ReshapeRequest& reshape_request);
+
+ // Enqueues a slice instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddSliceInstruction(
+ const SliceRequest& slice_request);
+
+ // Enqueues a dynamic slice instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddDynamicSliceInstruction(
+ const DynamicSliceRequest& dynamic_slice_request);
+
+ // Enqueues a dynamic update slice instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction(
+ const DynamicUpdateSliceRequest& dynamic_update_slice_request);
+
+ // Enqueues a concatenate instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddConcatenateInstruction(
+ const ConcatenateRequest& concatenate_request);
+
+ // Enqueues a convert instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddConvertInstruction(
+ const ConvertRequest& convert_request);
+
+ // Enqueues a reduce instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddReduceInstruction(
+ const ReduceRequest& reduce_request,
+ const UserComputation& reduction_computation);
+
+ // Enqueues a windowed reduce instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
+ const ReduceWindowRequest& reduce_window_request,
+ const UserComputation& reduction_computation);
+
+ // Enqueues a select-and-scatter instruction onto this user
+ // computation.
+ StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
+ const SelectAndScatterRequest& scatter_to_selected_window_element_request,
+ const UserComputation& select_computation,
+ const UserComputation& scatter_computation);
+
+ // Enqueues a reverse instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddReverseInstruction(
+ const ReverseRequest& reverse_request);
+
+ // Enqueues a while instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddWhileInstruction(
+ const WhileRequest& while_request,
+ const UserComputation& condition_computation,
+ const UserComputation& body_computation);
+
+ // Enqueues a Send instruction onto this user computation.
+ Status AddSendInstruction(const SendRequest& send_request);
+
+ // Enqueues a Recv instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddRecvInstruction(
+ const RecvRequest& recv_request);
+
+ // Returns the user-provided name of this user computation, which is provided
+ // via the XLA computation-building API.
+ const string& name() const { return name_; }
+
+ // Subsequent executions of this computation will compute the value
+ // represented by handle, rather than the last expression enqueued
+ // on the computation.
+ Status SetReturnValue(const ComputationDataHandle& handle);
+
+ // Return a versioned handle for this computation.
+ VersionedComputationHandle GetVersionedHandle() const;
+
+ // Return a versioned handle for this computation with a version equal to the
+ // point at which the given operation was added to the computation.
+ VersionedComputationHandle GetVersionedHandleAtOperation(
+ const ComputationDataHandle& operation) const;
+
+ // Return a version value representing the current state of the
+ // computation.
+ VersionedComputationHandle::Version version() const;
+
+ // Computes and returns the program shape for the user computation -- gathers
+ // parameters and result type into a single proto. A shared_ptr is used
+ // because the returned pointer refers to an internally cached value which may
+ // be discarded by the UserComputation object. This avoids unnecessary copies.
+ //
+ // If the parameter space is not dense (i.e. there are holes in the parameter
+ // numbers provided) then an error status is returned.
+ StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
+ VersionedComputationHandle::Version version) const;
+
+ // Returns true if the given data handle does not depend on any
+ // parameters. That is, the value can be computed at compile time.
+ StatusOr<bool> IsConstant(const ComputationDataHandle& handle);
+
+ // Returns the output shape of the operation indicated by the given handle.
+ StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
+
+ // Builds a HLO computation from the UserComputation. The parameter "resolver"
+ // is a function which returns a pointer to the HloComputation corresponding
+ // to the given ComputationHandle at the given version. The resolver is used
+ // for operations, such as map, which call other computations and need a
+ // pointer to the called HloComputation to construct the respective HLO
+ // instructions. If include_unused_parameters is true, then all parameter
+ // instructions are lowered into HloInstructions even if the parameter is
+ // unused (the root of the computation is unreachable from the parameter).
+ using HloComputationResolver =
+ std::function<HloComputation*(const VersionedComputationHandle& handle)>;
+ StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation(
+ VersionedComputationHandle::Version version,
+ HloComputationResolver hlo_resolver,
+ bool include_unused_parameters = true) const;
+
+ // Return a vector containing the embedded computations used by this
+ // UserComputation. Only embedded computations which are called directly by
+ // this UserComputation are included. That is, the transitive closure of
+ // embedded computations is not included.
+ std::vector<VersionedComputationHandle> GetEmbeddedComputations(
+ VersionedComputationHandle::Version version) const;
+
+ // Returns the number of OperationRequest objects in this UserComputation.
+ // The 'version' of a computation is identical to the number of
+ // OperationRequests in the UserComputation.
+ int64 request_count(VersionedComputationHandle::Version version) const {
+ return version;
+ }
+
+ // Returns a copy of the internal session state for this computation -- this
+ // is useful for serializing the guts of a user computation, though references
+ // to other handles (e.g. referred-to computations) must be handled with care
+ // in the serialization / de-serialization process.
+ SessionComputation CloneSessionComputation(
+ VersionedComputationHandle::Version version) const;
+
+ private:
+ // Warning: dangerous mutating operation that doesn't respect versioning.
+ // This is only used at initialization time when constructing from a
+ // SessionComputation a la MakeWithRemapping.
+ //
+ // Remaps references to old computations (with handle values in the keys of
+ // old_to_new) to the computation handle given in the values. This is useful
+ // when loading computations from snapshots, to finish initialization, before
+ // the user computation is released into the wild.
+ Status RemapEmbeddedComputations(
+ const std::map<int64, ComputationHandle>& old_to_new)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Returns the OperationRequest corresponding to the root (result) of the
+ // computation.
+ const OperationRequest& GetRoot(VersionedComputationHandle::Version version)
+ const EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Returns the OperationRequest corresponding to the given handle value.
+ StatusOr<const OperationRequest*> LookupRequest(
+ const ComputationDataHandle& handle) const
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Creates a new ComputationDataHandle with the next available handle value.
+ ComputationDataHandle CreateComputationDataHandle()
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Checks whether the parameter numbers of the parameter operations are
+ // contiguous starting from zero. Returns appropriate error status if not.
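+ // For example (illustrative): parameter numbers {0, 1, 2} are contiguous,
+ // while {0, 2} is missing parameter 1 and yields a FailedPrecondition error.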
+ Status CheckParametersAreContiguous(
+ VersionedComputationHandle::Version version) const
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ // Name of the computation.
+ string name_;
+
+ mutable tensorflow::mutex mutex_;
+
+ // State of the computation as a record of all operation-building requests.
+ SessionComputation session_computation_ GUARDED_BY(mutex_);
+
+ // Mapping from parameter number to operation request containing the
+ // respective ParameterRequest.
+ std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_);
+
+ // The next ComputationDataHandle value to assign. Handle values are assigned
+ // sequentially.
+ int64 next_handle_value_ GUARDED_BY(mutex_);
+
+ // If handle_to_return_.has_handle() then an Execution of this Computation
+ // will compute the value represented by handle_to_return_, otherwise it will
+ // compute the value of (next_handle_value_ - 1).
+ ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_);
+
+ // Memoized ProgramShape and its version. A shared_ptr is used because
+ // references to this object are returned by ComputeProgramShape.
+ mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0;
+ mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(UserComputation);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
diff --git a/tensorflow/compiler/xla/service/versioned_computation_handle.h b/tensorflow/compiler/xla/service/versioned_computation_handle.h
new file mode 100644
index 0000000000..03bee3d4a5
--- /dev/null
+++ b/tensorflow/compiler/xla/service/versioned_computation_handle.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
+
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A data structure encapsulating a ComputationHandle and version value of that
+// computation. This object is used to unambiguously refer to a particular
+// computation in the service.
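+//
+// The comparison operators below are provided so that, for example, a
+// VersionedComputationHandle can be used as a key in ordered containers such
+// as std::map (illustrative note, not part of the original comment).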
+struct VersionedComputationHandle {
+ // A version value unambiguously specifying the state of the computation at a
+ // particular point in time as it is being built. This value is the
+ // ComputationDataHandle of the current root instruction.
+ using Version = int64;
+
+ ComputationHandle handle;
+ Version version;
+ bool operator==(const VersionedComputationHandle& other) const {
+ return (handle.handle() == other.handle.handle()) &&
+ (version == other.version);
+ }
+ bool operator<(const VersionedComputationHandle& other) const {
+ return ((handle.handle() < other.handle.handle()) ||
+ ((handle.handle() == other.handle.handle()) &&
+ (version < other.version)));
+ }
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_VERSIONED_COMPUTATION_HANDLE_H_
diff --git a/tensorflow/compiler/xla/service_interface.h b/tensorflow/compiler/xla/service_interface.h
new file mode 100644
index 0000000000..fc107480f7
--- /dev/null
+++ b/tensorflow/compiler/xla/service_interface.h
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
+
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// Defines the interface for an XLA service.
+class ServiceInterface {
+ public:
+ ServiceInterface() {}
+ virtual ~ServiceInterface() = default;
+
+ // TODO(b/31824348): Convert to use StatusOr.
+ virtual tensorflow::Status TransferToClient(
+ const TransferToClientRequest* arg, TransferToClientResponse* result) = 0;
+
+ virtual tensorflow::Status TransferToClientInProcess(
+ const TransferToClientInProcessRequest* arg,
+ TransferToClientInProcessResponse* result) = 0;
+
+ virtual tensorflow::Status TransferToServer(
+ const TransferToServerRequest* arg, TransferToServerResponse* result) = 0;
+
+ virtual tensorflow::Status TransferToInfeed(
+ const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) = 0;
+
+ virtual tensorflow::Status ResetDevice(const ResetDeviceRequest* arg,
+ ResetDeviceResponse* result) = 0;
+
+ virtual tensorflow::Status TransferToServerInProcess(
+ const TransferToServerInProcessRequest* arg,
+ TransferToServerInProcessResponse* result) = 0;
+
+ virtual tensorflow::Status LoadComputationSnapshot(
+ const LoadComputationSnapshotRequest* request,
+ LoadComputationSnapshotResponse* result) = 0;
+
+ virtual tensorflow::Status Execute(const ExecuteRequest* arg,
+ ExecuteResponse* result) = 0;
+
+ virtual tensorflow::Status ExecuteParallel(
+ const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) = 0;
+
+ virtual tensorflow::Status ExecuteAsync(const ExecuteAsyncRequest* arg,
+ ExecuteAsyncResponse* result) = 0;
+
+ virtual tensorflow::Status WaitForExecution(
+ const WaitForExecutionRequest* arg, WaitForExecutionResponse* result) = 0;
+
+ virtual tensorflow::Status DeconstructTuple(
+ const DeconstructTupleRequest* arg, DeconstructTupleResponse* result) = 0;
+
+ virtual tensorflow::Status GetComputationStats(
+ const ComputationStatsRequest* arg, ComputationStatsResponse* result) = 0;
+
+ virtual tensorflow::Status GetComputationShape(
+ const GetComputationShapeRequest* arg,
+ GetComputationShapeResponse* result) = 0;
+
+ virtual tensorflow::Status GetShape(const GetShapeRequest* arg,
+ GetShapeResponse* result) = 0;
+
+ virtual tensorflow::Status CreateChannelHandle(
+ const CreateChannelHandleRequest* arg,
+ CreateChannelHandleResponse* result) = 0;
+
+ virtual tensorflow::Status GetDeviceHandles(
+ const GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) = 0;
+
+ // Methods used by ComputationBuilder.
+ virtual tensorflow::Status Computation(const ComputationRequest* arg,
+ ComputationResponse* result) = 0;
+
+ virtual tensorflow::Status Op(const OpRequest* arg, OpResponse* result) = 0;
+
+ virtual tensorflow::Status GetLocalShape(const GetLocalShapeRequest* arg,
+ GetLocalShapeResponse* result) = 0;
+
+ virtual tensorflow::Status SetReturnValue(
+ const SetReturnValueRequest* arg, SetReturnValueResponse* result) = 0;
+
+ virtual tensorflow::Status IsConstant(const IsConstantRequest* arg,
+ IsConstantResponse* result) = 0;
+
+ virtual tensorflow::Status ComputeConstant(
+ const ComputeConstantRequest* arg, ComputeConstantResponse* result) = 0;
+
+ // Methods used by Computation.
+ virtual tensorflow::Status SnapshotComputation(
+ const SnapshotComputationRequest* arg,
+ SnapshotComputationResponse* result) = 0;
+
+ // Methods used by GlobalData.
+ virtual tensorflow::Status Unregister(const UnregisterRequest* arg,
+ UnregisterResponse* result) = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INTERFACE_H_
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
new file mode 100644
index 0000000000..5bf9842a6c
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_layout.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
+ if (!ShapeUtil::Compatible(other_shape, shape_)) {
+ return InvalidArgument("Shape %s is not compatible with shape %s",
+ ShapeUtil::HumanString(other_shape).c_str(),
+ ShapeUtil::HumanString(shape()).c_str());
+ }
+ shape_ = other_shape;
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const {
+ if (!ShapeUtil::Compatible(*other_shape, shape_)) {
+ return InvalidArgument("Shape %s is not compatible with shape %s",
+ ShapeUtil::HumanString(*other_shape).c_str(),
+ ShapeUtil::HumanString(shape()).c_str());
+ }
+ *other_shape = shape_;
+ return tensorflow::Status::OK();
+}
+
+void ShapeLayout::SetToDefaultLayout() {
+ LayoutUtil::SetToDefaultLayout(&shape_);
+}
+
+bool ShapeLayout::MatchesLayoutInShape(const Shape& shape) const {
+ return ShapeUtil::Equal(shape, shape_);
+}
+
+const Layout& ShapeLayout::layout() const {
+ CHECK(LayoutIsSet());
+ CHECK(!ShapeUtil::IsTuple(shape_));
+ return shape_.layout();
+}
+
+void ShapeLayout::Clear() { LayoutUtil::ClearLayout(&shape_); }
+
+bool ShapeLayout::LayoutIsSet() const { return LayoutUtil::HasLayout(shape_); }
+
+void ShapeLayout::ResetLayout(const Layout& layout) {
+ CHECK(!ShapeUtil::IsTuple(shape_));
+ CHECK(!ShapeUtil::IsOpaque(shape_));
+ *shape_.mutable_layout() = layout;
+ TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
+}
+
+bool ShapeLayout::operator==(const ShapeLayout& other) const {
+ return ShapeUtil::Equal(shape_, other.shape_);
+}
+
+bool ShapeLayout::operator!=(const ShapeLayout& other) const {
+ return !ShapeUtil::Equal(shape_, other.shape_);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h
new file mode 100644
index 0000000000..92564660f2
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_layout.h
@@ -0,0 +1,88 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
+#define TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace xla {
+
+// A ShapeLayout object encapsulates the layout of a particular shape (including
+// tuples). This differs from the Layout proto which describes the layout of a
+// single array. ShapeLayout contains a Layout proto for each array in the shape
+// (a tuple can have more than one array). For array shapes, this object
+// trivially holds a single Layout. Logically, ShapeLayout holds an immutable
+// shape with mutable layouts.
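+//
+// Illustrative usage sketch (the shape is chosen arbitrarily):
+//
+//   ShapeLayout shape_layout(ShapeUtil::MakeShape(F32, {2, 3}));
+//   shape_layout.SetToDefaultLayout();
+//   Shape target = ShapeUtil::MakeShape(F32, {2, 3});
+//   TF_CHECK_OK(shape_layout.AssignLayoutToShape(&target));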
+class ShapeLayout {
+ public:
+ // Constructs a ShapeLayout of the given shape. Layouts are copied from the
+ // shape parameter.
+ explicit ShapeLayout(const Shape& shape) : shape_(shape) {}
+
+ // Assigns the layouts in this ShapeLayout to the Layout fields of the given
+ // shape. 'shape' and the shape of the ShapeLayout object must be compatible.
+ tensorflow::Status AssignLayoutToShape(Shape* shape) const;
+
+ // Returns true if the Layouts in this ShapeLayout match the layouts in the
+ // given shape. Returns false otherwise. If the given shape is not compatible
+ // with the ShapeLayout's shape, then false is returned.
+ bool MatchesLayoutInShape(const Shape& shape) const;
+
+ // Copies the layout from the given shape into this ShapeLayout. 'shape' must
+ // be compatible with the ShapeLayout's shape, and 'shape' must have a layout
+ // (LayoutUtil::HasLayout).
+ tensorflow::Status CopyLayoutFromShape(const Shape& shape);
+
+ // Clears (Layout::Clear) all the Layouts stored in this object.
+ void Clear();
+
+ // Sets all Layouts stored in this object to the default layout.
+ void SetToDefaultLayout();
+
+ // Returns the shape (with layouts).
+ const Shape& shape() const { return shape_; }
+
+ // Checks that a layout is set for the shape, and returns a reference to the
+ // layout directly on the shape. Shape must not be a tuple.
+ const Layout& layout() const;
+
+ // Returns true if all layouts have been set for this ShapeLayout object. That
+ // is, every array has a layout.
+ bool LayoutIsSet() const;
+
+ // Resets the layout on the shape to the provided layout. Shape must not be a
+ // tuple.
+ void ResetLayout(const Layout& layout);
+
+ // Returns a string representation of this object.
+ string ToString() const { return ShapeUtil::HumanStringWithLayout(shape_); }
+
+ // Tests for equality of both shape and layout (ShapeUtil::Equal).
+ bool operator==(const ShapeLayout& other) const;
+ bool operator!=(const ShapeLayout& other) const;
+
+ private:
+ Shape shape_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SHAPE_LAYOUT_H_
diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h
new file mode 100644
index 0000000000..6963a68d10
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_tree.h
@@ -0,0 +1,260 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
+#define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A ShapeTree<T> is a recursive data structure which mirrors the structure of
+// an XLA shape and holds a value of type T for each array in the shape. For
+// array shapes, a ShapeTree trivially holds a single value of type T. For tuple
+// shapes which can be an arbitrary tree with arrays at the leaves, a ShapeTree
+// is an identically structured tree with data elements of type T at the leaves.
+//
+// Like the Shape data structure, this is a tree and tuple elements cannot be
+// duplicated. That is, every distinct element position in the Shape has a
+// unique T object.
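+//
+// Illustrative usage sketch ("tuple_shape" is a placeholder for any tuple
+// Shape):
+//
+//   ShapeTree<int> tree(tuple_shape, /*init_value=*/0);
+//   *tree.mutable_element({0}) = 42;  // set data for the first tuple element
+//   int value = tree.element({0});    // reads back 42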
+template <typename T>
+class ShapeTree {
+ public:
+ explicit ShapeTree(const Shape& shape);
+ ShapeTree(const Shape& shape, const T& init_value);
+ ShapeTree(const ShapeTree<T>& other);
+ ShapeTree<T>& operator=(const ShapeTree<T>& other);
+
+ // Returns the data element associated with the array in the shape at the
+ // given index (see ShapeUtil::GetSubshape for how indexes are defined).
+ const T& element(const ShapeIndex& index) const;
+ T* mutable_element(const ShapeIndex& index);
+
+ // Returns the shape represented by this ShapeTree.
+ const Shape& shape() const { return *shape_; }
+
+ // Returns true if the node at the given index is a leaf node (an array
+ // shape).
+ bool IsLeaf(const ShapeIndex& index) const {
+ return Lookup(index).elements_.empty();
+ }
+
+ // Recursively traverses the shape and calls the given function at each
+ // element. The function has the following arguments:
+ //
+ // index : the index of the element in the shape. See ShapeUtil::GetSubshape
+ // for definition of index.
+ // is_leaf : Whether this element is a leaf element in the shape. That is,
+ // whether this index corresponds to an array and not a (nested)
+ // tuple element.
+ // data : The data value at this element.
+ //
+ // If any call to the given function returns a non-OK status, then traversal
+ // is aborted and the status value is returned.
+ using VisitorFunction = std::function<tensorflow::Status(
+ const ShapeIndex& /*index*/, bool /*is_leaf*/, const T& /*data*/)>;
+ tensorflow::Status ForEachElement(VisitorFunction func) const;
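+
+ // Example (illustrative, mirroring shape_tree_test.cc): summing all data
+ // values of a ShapeTree<int> named "tree":
+ //
+ //   int sum = 0;
+ //   TF_CHECK_OK(tree.ForEachElement(
+ //       [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) {
+ //         sum += data;
+ //         return tensorflow::Status::OK();
+ //       }));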
+
+ using MutableVisitorFunction = std::function<tensorflow::Status(
+ const ShapeIndex& /*index*/, bool /*is_leaf*/, T* /*data*/)>;
+ tensorflow::Status ForEachMutableElement(MutableVisitorFunction func);
+
+ private:
+ // Private default constructor for non-root nodes of the tree.
+ ShapeTree() = default;
+
+ // Helpers for traversing the shape via ForEachElement. The helpers
+ // recursively traverse the subtree rooted at "index" (defined as in
+ // ShapeUtil::GetSubshape).
+ static tensorflow::Status ForEachHelperMutable(ShapeIndex* index,
+ ShapeTree<T>* shape_tree,
+ MutableVisitorFunction func);
+ static tensorflow::Status ForEachHelper(ShapeIndex* index,
+ const ShapeTree<T>& shape_tree,
+ VisitorFunction func);
+
+ // Copy all the data elements (of type T) from "other" into "this". "this"
+ // must have the same tree structure as "other" prior to calling this method.
+ void CopyDataElements(const ShapeTree<T>& other);
+
+ // Recursive helper for constructing a subtree beneath "this" node.
+ void BuildTree(const Shape& shape);
+
+ // Return the tree node at the given index.
+ ShapeTree<T>& Lookup(const ShapeIndex& index);
+ const ShapeTree<T>& Lookup(const ShapeIndex& index) const;
+
+ // The data corresponding to the array at this node.
+ T data_;
+
+ // The XLA shape mirrored in this ShapeTree. Only the root of the
+ // ShapeTree has this member set.
+ std::unique_ptr<Shape> shape_;
+
+ // The children of this node in the tree.
+ std::vector<std::unique_ptr<ShapeTree>> elements_;
+};
+
+template <typename T>
+void ShapeTree<T>::BuildTree(const Shape& shape) {
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ elements_.emplace_back(new ShapeTree());
+ elements_.back()->BuildTree(shape.tuple_shapes(i));
+ }
+ }
+}
+
+template <typename T>
+ShapeTree<T>::ShapeTree(const Shape& shape) : shape_(MakeUnique<Shape>(shape)) {
+ // The shape_ field is just used to hold the structure of the shape. It should
+ // not be relied upon to store layout information.
+ LayoutUtil::ClearLayout(shape_.get());
+ BuildTree(*shape_);
+}
+
+template <typename T>
+ShapeTree<T>::ShapeTree(const Shape& shape, const T& init_value)
+ : shape_(MakeUnique<Shape>(shape)) {
+ LayoutUtil::ClearLayout(shape_.get());
+ BuildTree(*shape_);
+ TF_CHECK_OK(ForEachMutableElement(
+ [&init_value](const ShapeIndex& /*index*/, bool /*is_leaf*/, T* data) {
+ *data = init_value;
+ return tensorflow::Status::OK();
+ }));
+}
+
+template <typename T>
+ShapeTree<T>::ShapeTree(const ShapeTree& other)
+ : shape_(MakeUnique<Shape>(other.shape())) {
+ LayoutUtil::ClearLayout(shape_.get());
+ BuildTree(*shape_);
+ CopyDataElements(other);
+}
+
+template <typename T>
+ShapeTree<T>& ShapeTree<T>::operator=(const ShapeTree<T>& other) {
+ if (this == &other) {
+ return *this;
+ }
+ elements_.clear();
+ shape_ = MakeUnique<Shape>(other.shape());
+ LayoutUtil::ClearLayout(shape_.get());
+
+ BuildTree(*shape_);
+ CopyDataElements(other);
+ return *this;
+}
+
+template <typename T>
+void ShapeTree<T>::CopyDataElements(const ShapeTree<T>& other) {
+ CHECK(ShapeUtil::Compatible(shape(), other.shape()));
+ TF_CHECK_OK(ForEachMutableElement(
+ [&other](const ShapeIndex& index, bool /*is_leaf*/, T* data) {
+ *data = other.element(index);
+ return tensorflow::Status::OK();
+ }));
+}
+
+template <typename T>
+const T& ShapeTree<T>::element(const ShapeIndex& index) const {
+ return Lookup(index).data_;
+}
+
+template <typename T>
+T* ShapeTree<T>::mutable_element(const ShapeIndex& index) {
+ return &Lookup(index).data_;
+}
+
+template <typename T>
+ShapeTree<T>& ShapeTree<T>::Lookup(const ShapeIndex& index) {
+ ShapeTree<T>* node = this;
+ for (auto& i : index) {
+ CHECK_GE(i, 0);
+ CHECK_LT(i, node->elements_.size());
+ node = node->elements_[i].get();
+ }
+ return *node;
+}
+
+template <typename T>
+const ShapeTree<T>& ShapeTree<T>::Lookup(const ShapeIndex& index) const {
+ return const_cast<ShapeTree<T>*>(this)->Lookup(index);
+}
+
+/* static */
+template <typename T>
+tensorflow::Status ShapeTree<T>::ForEachHelperMutable(
+ ShapeIndex* index, ShapeTree<T>* shape_tree,
+ ShapeTree<T>::MutableVisitorFunction func) {
+ TF_RETURN_IF_ERROR(
+ func(*index, shape_tree->elements_.empty(), &shape_tree->data_));
+ for (int i = 0; i < shape_tree->elements_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(
+ ForEachHelperMutable(index, shape_tree->elements_[i].get(), func));
+ index->pop_back();
+ }
+
+ return tensorflow::Status::OK();
+}
+
+/* static */
+template <typename T>
+tensorflow::Status ShapeTree<T>::ForEachHelper(
+ ShapeIndex* index, const ShapeTree<T>& shape_tree,
+ ShapeTree<T>::VisitorFunction func) {
+ TF_RETURN_IF_ERROR(
+ func(*index, shape_tree.elements_.empty(), shape_tree.data_));
+ for (int i = 0; i < shape_tree.elements_.size(); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachHelper(index, *shape_tree.elements_[i], func));
+ index->pop_back();
+ }
+
+ return tensorflow::Status::OK();
+}
+
+template <typename T>
+tensorflow::Status ShapeTree<T>::ForEachElement(
+ ShapeTree<T>::VisitorFunction func) const {
+ ShapeIndex index;
+ return ForEachHelper(&index, *this, func);
+}
+
+template <typename T>
+tensorflow::Status ShapeTree<T>::ForEachMutableElement(
+ ShapeTree<T>::MutableVisitorFunction func) {
+ ShapeIndex index;
+ return ForEachHelperMutable(&index, this, func);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
diff --git a/tensorflow/compiler/xla/shape_tree_test.cc b/tensorflow/compiler/xla/shape_tree_test.cc
new file mode 100644
index 0000000000..d37f536b75
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_tree_test.cc
@@ -0,0 +1,134 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_tree.h"
+
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class ShapeTreeTest : public ::testing::Test {
+ protected:
+ ShapeTreeTest() {
+ array_shape_ = ShapeUtil::MakeShape(F32, {42, 42, 123});
+ tuple_shape_ =
+ ShapeUtil::MakeTupleShape({array_shape_, array_shape_, array_shape_});
+ nested_tuple_shape_ = ShapeUtil::MakeTupleShape(
+ {array_shape_, ShapeUtil::MakeTupleShape({array_shape_, array_shape_}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({array_shape_, array_shape_}),
+ array_shape_})});
+ }
+
+ // An array shape (non-tuple).
+ Shape array_shape_;
+
+ // A three element tuple shape.
+ Shape tuple_shape_;
+
+ // A nested tuple shape of the following form: (a, (c, d), ((e, f), g))
+ Shape nested_tuple_shape_;
+};
+
+TEST_F(ShapeTreeTest, ArrayShape) {
+ ShapeTree<int> shape_tree{array_shape_};
+ *shape_tree.mutable_element({}) = 42;
+ EXPECT_EQ(42, shape_tree.element({}));
+ *shape_tree.mutable_element({}) = 123;
+ EXPECT_EQ(123, shape_tree.element({}));
+
+ EXPECT_TRUE(ShapeUtil::Compatible(array_shape_, shape_tree.shape()));
+
+ // Test the copy constructor.
+ ShapeTree<int> copy{shape_tree};
+ EXPECT_EQ(123, copy.element({}));
+}
+
+TEST_F(ShapeTreeTest, TupleShape) {
+ ShapeTree<int> shape_tree{tuple_shape_};
+ *shape_tree.mutable_element({}) = 1;
+ *shape_tree.mutable_element({0}) = 42;
+ *shape_tree.mutable_element({1}) = 123;
+ *shape_tree.mutable_element({2}) = -100;
+ EXPECT_EQ(1, shape_tree.element({}));
+ EXPECT_EQ(42, shape_tree.element({0}));
+ EXPECT_EQ(123, shape_tree.element({1}));
+ EXPECT_EQ(-100, shape_tree.element({2}));
+
+ EXPECT_TRUE(ShapeUtil::Compatible(tuple_shape_, shape_tree.shape()));
+
+ // Sum all elements in the shape.
+ int sum = 0;
+ TF_CHECK_OK(shape_tree.ForEachElement(
+ [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) {
+ sum += data;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(66, sum);
+
+ // Test the copy constructor.
+ ShapeTree<int> copy{shape_tree};
+ EXPECT_EQ(1, copy.element({}));
+ EXPECT_EQ(42, copy.element({0}));
+ EXPECT_EQ(123, copy.element({1}));
+ EXPECT_EQ(-100, copy.element({2}));
+
+ // Write zero to all data elements.
+ TF_CHECK_OK(shape_tree.ForEachMutableElement(
+ [&sum](const ShapeIndex& /*index*/, bool /*is_leaf*/, int* data) {
+ *data = 0;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(0, shape_tree.element({}));
+ EXPECT_EQ(0, shape_tree.element({0}));
+ EXPECT_EQ(0, shape_tree.element({1}));
+ EXPECT_EQ(0, shape_tree.element({2}));
+}
+
+TEST_F(ShapeTreeTest, NestedTupleShape) {
+ ShapeTree<int> shape_tree{nested_tuple_shape_};
+ *shape_tree.mutable_element({0}) = 42;
+ *shape_tree.mutable_element({1, 1}) = 123;
+ *shape_tree.mutable_element({2, 0, 1}) = -100;
+ EXPECT_EQ(42, shape_tree.element({0}));
+ EXPECT_EQ(123, shape_tree.element({1, 1}));
+ EXPECT_EQ(-100, shape_tree.element({2, 0, 1}));
+
+ EXPECT_TRUE(ShapeUtil::Compatible(nested_tuple_shape_, shape_tree.shape()));
+
+ // Test the copy constructor.
+ ShapeTree<int> copy{shape_tree};
+ EXPECT_EQ(42, copy.element({0}));
+ EXPECT_EQ(123, copy.element({1, 1}));
+ EXPECT_EQ(-100, copy.element({2, 0, 1}));
+}
+
+TEST_F(ShapeTreeTest, InvalidIndexingTuple) {
+ ShapeTree<int> shape_tree{tuple_shape_};
+
+ EXPECT_DEATH(shape_tree.element({4}), "");
+}
+
+TEST_F(ShapeTreeTest, InvalidIndexingNestedTuple) {
+ ShapeTree<int> shape_tree{nested_tuple_shape_};
+
+ EXPECT_DEATH(shape_tree.element({0, 0}), "");
+}
+
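+// An illustrative sketch, not part of the original change: it exercises the
+// (shape, init_value) constructor together with ForEachElement, assuming the
+// constructor writes init_value into every node, interior and leaf alike.
+TEST_F(ShapeTreeTest, InitValueConstructorSketch) {
+  ShapeTree<int> shape_tree{tuple_shape_, 7};
+  int node_count = 0;
+  TF_CHECK_OK(shape_tree.ForEachElement(
+      [&node_count](const ShapeIndex& /*index*/, bool /*is_leaf*/, int data) {
+        // Every node should carry the initial value.
+        EXPECT_EQ(7, data);
+        ++node_count;
+        return tensorflow::Status::OK();
+      }));
+  // One root node plus three leaves for the three-element tuple.
+  EXPECT_EQ(4, node_count);
+}
+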
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
new file mode 100644
index 0000000000..a8878e7941
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -0,0 +1,1024 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include <algorithm>
+#include <functional>
+#include <numeric>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace xla {
+
+/* static */ bool ShapeUtil::CompareShapes(const Shape& lhs, const Shape& rhs,
+ bool compare_layouts) {
+ if (IsTuple(lhs)) {
+ return IsTuple(rhs) &&
+ ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
+ [=](const Shape& l, const Shape& r) {
+ return CompareShapes(l, r, compare_layouts);
+ });
+ }
+ // Explicitly compare the fields rather than using MessageDifferencer because
+ // we want empty layouts to be treated identically to missing layouts.
+ if (compare_layouts &&
+ (!ContainersEqual(lhs.layout().minor_to_major(),
+ rhs.layout().minor_to_major()) ||
+ !ContainersEqual(lhs.layout().padded_dimensions(),
+ rhs.layout().padded_dimensions()) ||
+ lhs.layout().padding_value() != rhs.layout().padding_value())) {
+ return false;
+ }
+ return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
+}
+
+/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
+ bool equal = CompareShapes(lhs, rhs, /*compare_layouts=*/true);
+ if (!equal && VLOG_IS_ON(3)) {
+ // TODO(jeff): Maybe print more info about where lhs and rhs differ
+ VLOG(3) << "ShapeUtil::Equal differ: lhs = " << lhs.ShortDebugString()
+ << ", rhs = " << rhs.ShortDebugString();
+ }
+
+ return equal;
+}
+
+/* static */ int64 ShapeUtil::TrueRank(const Shape& shape) {
+ int64 accum = 0;
+ for (int64 dimension : shape.dimensions()) {
+    // We do not count trivial dimensions of size 1. Note that zero-sized
+    // dimensions are counted.
+ if (dimension != 1) {
+ accum += 1;
+ }
+ }
+ return accum;
+}
+
+/* static */ ProgramShape ShapeUtil::MakeProgramShape(
+ std::initializer_list<Shape> parameters, Shape result) {
+ ProgramShape program_shape;
+ for (const auto& shape : parameters) {
+ *program_shape.add_parameters() = shape;
+ }
+ *program_shape.mutable_result() = result;
+ return program_shape;
+}
+
+/* static */ Shape ShapeUtil::MakeShape(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ DCHECK_NE(TUPLE, element_type);
+ DCHECK_NE(OPAQUE, element_type);
+ Shape result;
+ PopulateShape(element_type, dimensions, &result);
+ return result;
+}
+
+/* static */ Shape ShapeUtil::MakeShapeWithLayout(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ CHECK_EQ(dimensions.size(), minor_to_major.size());
+ Shape shape = MakeShape(element_type, dimensions);
+ auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
+ min2maj->Clear();
+ for (int64 value : minor_to_major) {
+ min2maj->Add(value);
+ }
+ DCHECK(shape.has_layout());
+ TF_DCHECK_OK(ValidateShape(shape));
+ return shape;
+}
+
+/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions) {
+ std::vector<int64> layout(dimensions.size());
+ std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
+ return MakeShapeWithLayout(element_type, dimensions, layout);
+}
+
+/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(
+ const Shape& shape) {
+ std::vector<int64> dims(shape.dimensions_size());
+ for (int i = 0; i < shape.dimensions_size(); ++i) {
+ dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i));
+ }
+ return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims);
+}
+
+/* static */ void ShapeUtil::PopulateShape(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ Shape* shape) {
+ shape->Clear();
+ shape->set_element_type(element_type);
+ for (int64 dimension : dimensions) {
+ shape->add_dimensions(dimension);
+ }
+ LayoutUtil::SetToDefaultLayout(shape);
+ TF_DCHECK_OK(ValidateShape(*shape));
+}
+
+/* static */ Shape ShapeUtil::MakeTupleShape(
+ tensorflow::gtl::ArraySlice<Shape> shapes) {
+ Shape result;
+ result.set_element_type(TUPLE);
+ for (const auto& shape : shapes) {
+ AppendShapeToTuple(shape, &result);
+ }
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
+ return result;
+}
+
+/* static */ Shape ShapeUtil::MakeOpaqueShape() {
+ Shape result;
+ result.set_element_type(OPAQUE);
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(result));
+ return result;
+}
+
+/* static */ void ShapeUtil::AppendShapeToTuple(const Shape& shape,
+ Shape* tuple_shape) {
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
+ *tuple_shape->add_tuple_shapes() = shape;
+}
+
+/* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) {
+ shape->mutable_layout()->add_minor_to_major(ShapeUtil::Rank(*shape));
+ shape->add_dimensions(bound);
+ TF_DCHECK_OK(ValidateShape(*shape));
+}
+
+/* static */ bool ShapeUtil::ElementIsIntegral(const Shape& shape) {
+ return primitive_util::IsIntegralType(shape.element_type());
+}
+
+/* static */ bool ShapeUtil::ElementIsIntegralWithBits(const Shape& shape,
+ int32 bits) {
+ return ElementIsIntegral(shape) && ElementHasBitWidth(shape, bits);
+}
+
+/* static */ bool ShapeUtil::ElementHasBitWidth(const Shape& shape, int bits) {
+ if (shape.element_type() == TUPLE || shape.element_type() == OPAQUE) {
+ return false;
+ }
+ return primitive_util::BitWidth(shape.element_type()) == bits;
+}
+
+/* static */ bool ShapeUtil::ElementIsSigned(const Shape& shape) {
+ switch (shape.element_type()) {
+ case S8:
+ case S16:
+ case S32:
+ case S64:
+ case F16:
+ case F32:
+ case F64:
+ return true;
+
+ case PRED:
+ case U8:
+ case U16:
+ case U32:
+ case U64:
+ case TUPLE:
+ case OPAQUE:
+ return false;
+
+ default:
+ LOG(FATAL) << "Unhandled element type " << shape.element_type();
+ }
+}
+
+/* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) {
+ return primitive_util::IsFloatingPointType(shape.element_type());
+}
+
+/* static */ bool ShapeUtil::IsTuple(const Shape& shape) {
+ return shape.element_type() == TUPLE;
+}
+
+/* static */ bool ShapeUtil::IsArray(const Shape& shape) {
+ return !IsTuple(shape) && !IsOpaque(shape);
+}
+
+/* static */ bool ShapeUtil::IsNestedTuple(const Shape& shape) {
+ return IsTuple(shape) && std::any_of(shape.tuple_shapes().begin(),
+ shape.tuple_shapes().end(), IsTuple);
+}
+
+/* static */ bool ShapeUtil::IsEmptyTuple(const Shape& shape) {
+ return IsTuple(shape) && TupleElementCount(shape) == 0;
+}
+
+/* static */ bool ShapeUtil::IsNil(const Shape& shape) {
+ return IsEmptyTuple(shape) || HasZeroElements(shape);
+}
+
+/* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) {
+ CHECK(IsTuple(shape));
+ return shape.tuple_shapes_size();
+}
+
+/* static */ const Shape& ShapeUtil::GetTupleElementShape(const Shape& shape,
+ int64 index) {
+ CHECK(IsTuple(shape));
+ CHECK_GT(TupleElementCount(shape), index);
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape.tuple_shapes(index)));
+ return shape.tuple_shapes(index);
+}
+
+/* static */ Shape ShapeUtil::SliceTuple(const Shape& tuple, int64 start,
+ int64 limit) {
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(tuple));
+ CHECK(IsTuple(tuple));
+ CHECK_LE(start, TupleElementCount(tuple));
+ CHECK_LE(limit, TupleElementCount(tuple));
+
+ std::vector<Shape> new_elements(tuple.tuple_shapes().begin() + start,
+ tuple.tuple_shapes().begin() + limit);
+ return ShapeUtil::MakeTupleShape(new_elements);
+}
+
+/* static */ bool ShapeUtil::IsOpaque(const Shape& shape) {
+ return shape.element_type() == OPAQUE;
+}
+
+/* static */ bool ShapeUtil::ShapeIs(const Shape& shape,
+ PrimitiveType element_type,
+ std::initializer_list<int64> dimensions) {
+ TF_DCHECK_OK(ValidateShapeWithOptionalLayout(shape));
+ if (shape.element_type() != element_type) {
+ return false;
+ }
+  if (ShapeUtil::Rank(shape) != static_cast<int64>(dimensions.size())) {
+ return false;
+ }
+ int64 i = 0;
+ for (int64 dimension : dimensions) {
+ if (shape.dimensions(i) != dimension) {
+ return false;
+ }
+ i += 1;
+ }
+ return true;
+}
+
+/* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) {
+ CHECK_EQ(shape.dimensions_size(), ShapeUtil::Rank(shape));
+ return std::accumulate<decltype(shape.dimensions().begin()), int64>(
+ shape.dimensions().begin(), shape.dimensions().end(), 1LL,
+ std::multiplies<int64>());
+}
+
+/* static */ bool ShapeUtil::HasZeroElements(const Shape& shape) {
+ return ElementsIn(shape) == 0;
+}
+
+/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
+ return shape.element_type() == F32 && ShapeUtil::Rank(shape) == 0;
+}
+
+/* static */ string ShapeUtil::HumanString(const Shape& shape) {
+ if (shape.element_type() == TUPLE) {
+ string text = "(";
+ const char* prefix = "";
+ for (const Shape& elem_shape : shape.tuple_shapes()) {
+ tensorflow::strings::StrAppend(&text, prefix, HumanString(elem_shape));
+ prefix = ", ";
+ }
+ text += ")";
+ return text;
+ } else {
+ return tensorflow::strings::StrCat(
+ tensorflow::str_util::Lowercase(
+ PrimitiveType_Name(shape.element_type())),
+ "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]");
+ }
+}
+
+/* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) {
+ if (shape.element_type() == TUPLE) {
+ string text = "(";
+ const char* prefix = "";
+ for (const Shape& elem_shape : shape.tuple_shapes()) {
+ tensorflow::strings::StrAppend(&text, prefix,
+ HumanStringWithLayout(elem_shape));
+ prefix = ", ";
+ }
+ text += ")";
+ return text;
+ } else {
+ string layout;
+ if (!IsScalar(shape) && !IsOpaque(shape)) {
+ if (LayoutUtil::HasLayout(shape)) {
+ layout = tensorflow::strings::StrCat(
+ " ", LayoutUtil::HumanString(shape.layout()));
+ } else {
+ layout = " (no layout)";
+ }
+ }
+ return tensorflow::strings::StrCat(
+ tensorflow::str_util::Lowercase(
+ PrimitiveType_Name(shape.element_type())),
+ "[", tensorflow::str_util::Join(shape.dimensions(), ","), "]", layout);
+ }
+}
+
+/* static */ string ShapeUtil::HumanString(const ProgramShape& program_shape) {
+ std::vector<string> parameters;
+ for (auto& shape : program_shape.parameters()) {
+ const int i = parameters.size();
+ parameters.push_back(
+ tensorflow::strings::StrCat(i < program_shape.parameter_names_size()
+ ? program_shape.parameter_names(i)
+ : "(unknown)",
+ ": ", HumanString(shape)));
+ }
+ return tensorflow::strings::StrCat(
+ "(", tensorflow::str_util::Join(parameters, ", "), ") -> ",
+ HumanString(program_shape.result()));
+}
+
+/* static */ StatusOr<Shape> ShapeUtil::ParseShapeString(const string& s) {
+ string element_type_string;
+ string dimensions_string;
+ string layout_string;
+ if (RE2::FullMatch(s, "([fsu]32)\\[([\\d,]*)\\](?: {([\\d,]*)})?",
+ &element_type_string, &dimensions_string,
+ &layout_string)) {
+ auto comma_list_to_int64s =
+ [&s](const string& input) -> StatusOr<std::vector<int64>> {
+ std::vector<int64> results;
+ for (const string& piece : tensorflow::str_util::Split(input, ',')) {
+ int64 element;
+ if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) {
+ return InvalidArgument(
+ "invalid value in parsed shape string: \"%s\" in \"%s\"",
+ piece.c_str(), s.c_str());
+ }
+ results.push_back(element);
+ }
+ return results;
+ };
+ TF_ASSIGN_OR_RETURN(std::vector<int64> dimensions,
+ comma_list_to_int64s(dimensions_string));
+ PrimitiveType primitive_type;
+ if (element_type_string == "f32") {
+ primitive_type = F32;
+ } else if (element_type_string == "s32") {
+ primitive_type = S32;
+ } else if (element_type_string == "u32") {
+ primitive_type = U32;
+ } else {
+ LOG(FATAL) << "unhandled element type string: " << element_type_string;
+ }
+ Shape result;
+ if (layout_string.empty()) {
+ result = ShapeUtil::MakeShape(primitive_type, dimensions);
+ } else {
+ TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
+ comma_list_to_int64s(layout_string));
+ TF_RET_CHECK(dimensions.size() == min2maj.size());
+ result =
+ ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
+ }
+ TF_DCHECK_OK(ValidateShape(result));
+ return result;
+ }
+
+ return InvalidArgument("invalid shape string to parse: \"%s\"", s.c_str());
+}
+
+/* static */ bool ShapeUtil::SameDimensions(const Shape& lhs,
+ const Shape& rhs) {
+ return ContainersEqual(lhs.dimensions(), rhs.dimensions());
+}
+
+/* static */ bool ShapeUtil::Compatible(const Shape& lhs, const Shape& rhs) {
+ if (lhs.element_type() == TUPLE) {
+ return rhs.element_type() == TUPLE &&
+ ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), Compatible);
+ }
+ return SameDimensions(lhs, rhs) && SameElementType(lhs, rhs);
+}
+
+/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
+ int64 dimension_number) {
+ return shape.dimensions(GetDimensionNumber(shape, dimension_number));
+}
+
+/* static */ int64 ShapeUtil::GetDimensionNumber(const Shape& shape,
+ int64 dimension_number) {
+ if (dimension_number < 0) {
+ dimension_number += ShapeUtil::Rank(shape);
+ }
+ CHECK_GE(dimension_number, 0);
+ return dimension_number;
+}
+
+/* static */ int64 ShapeUtil::ByteSizeOfPrimitiveType(
+ PrimitiveType primitive_type) {
+ switch (primitive_type) {
+ case PRED:
+ return sizeof(int8);
+ case TUPLE:
+ LOG(FATAL) << "tuples have no definitive size";
+ case OPAQUE:
+ LOG(FATAL) << "opaque have no definitive size";
+ case S8:
+ return sizeof(int8);
+ case S16:
+ return sizeof(int16);
+ case S32:
+ return sizeof(int32);
+ case S64:
+ return sizeof(int64);
+ case U8:
+ return sizeof(uint8);
+ case U16:
+ return sizeof(uint16);
+ case U32:
+ return sizeof(uint32);
+ case U64:
+ return sizeof(uint64);
+ case F16:
+ return sizeof(float) / 2;
+ case F32:
+ return sizeof(float);
+ case F64:
+ return sizeof(double);
+ default:
+ LOG(FATAL) << "Unhandled primitive type " << primitive_type;
+ }
+}
+
+/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape,
+ int64 pointer_size) {
+ TF_DCHECK_OK(ValidateShape(shape));
+ DCHECK_NE(OPAQUE, shape.element_type());
+ if (shape.element_type() == TUPLE) {
+ return pointer_size * shape.tuple_shapes_size();
+ }
+ int64 allocated_element_count;
+ if (shape.layout().padded_dimensions_size() > 0) {
+ CHECK_EQ(ShapeUtil::Rank(shape), shape.layout().padded_dimensions_size());
+ allocated_element_count = 1;
+ for (int64 dimension_size : shape.layout().padded_dimensions()) {
+ allocated_element_count *= dimension_size;
+ }
+ } else {
+ allocated_element_count = ElementsIn(shape);
+ }
+ return allocated_element_count *
+ ByteSizeOfPrimitiveType(shape.element_type());
+}
+
+/* static */ int64 ShapeUtil::ByteSizeOf(const Shape& shape) {
+ return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*));
+}
+
+/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
+ const Shape& shape) {
+ if (shape.element_type() == TUPLE) {
+ // Tuple shape.
+ if (ShapeUtil::Rank(shape) != 0) {
+ return InvalidArgument("tuples must be rank-0; got rank %lld",
+ ShapeUtil::Rank(shape));
+ }
+ if (shape.dimensions_size() != 0) {
+ return InvalidArgument("tuples must not have dimensions specified");
+ }
+ for (auto& element_shape : shape.tuple_shapes()) {
+ TF_RETURN_IF_ERROR(
+ ValidateShapeWithOptionalLayoutInternal(element_shape));
+ }
+ return Status::OK();
+ }
+
+ // Non-tuple shape.
+ if (shape.tuple_shapes_size() > 0) {
+ return InvalidArgument("non-tuple shape has tuple_shapes field");
+ }
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ return InvalidArgument("shape has invalid element type: %s",
+ shape.ShortDebugString().c_str());
+ }
+ if (ShapeUtil::Rank(shape) != shape.dimensions_size()) {
+ return InvalidArgument(
+ "shape's rank is mismatched with dimension count; rank=%lld "
+ "dimensions_size=%d",
+ ShapeUtil::Rank(shape), shape.dimensions_size());
+ }
+ for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) {
+ int64 dimension = shape.dimensions(i);
+ if (dimension < 0) {
+ return InvalidArgument(
+ "shape's dimensions must not be < 0; dimension at index %lld was "
+ "%lld",
+ i, dimension);
+ }
+ }
+
+ return Status::OK();
+}
+
+/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayout(
+ const Shape& shape) {
+ if (LayoutUtil::HasLayout(shape)) {
+ // Since a layout is present, upgrade to the full set of invariant checks.
+ return ValidateShape(shape);
+ }
+ return ValidateShapeWithOptionalLayoutInternal(shape);
+}
+
+/* static */ Status ShapeUtil::ValidateShape(const Shape& shape) {
+ TF_RETURN_IF_ERROR(ValidateShapeWithOptionalLayoutInternal(shape));
+
+ return LayoutUtil::ValidateLayoutInShape(shape);
+}
+
+/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape,
+ PrimitiveType type) {
+ Shape new_shape = shape;
+ new_shape.set_element_type(type);
+ return new_shape;
+}
+
+/* static */ const Shape& ShapeUtil::GetSubshape(const Shape& shape,
+ const ShapeIndex& index) {
+ const Shape* return_shape = &shape;
+ for (auto i : index) {
+ CHECK(IsTuple(*return_shape));
+ return_shape = &return_shape->tuple_shapes(i);
+ }
+ return *return_shape;
+}
+
+/* static */ Shape* ShapeUtil::GetMutableSubshape(Shape* shape,
+ const ShapeIndex& index) {
+ Shape* return_shape = shape;
+ for (auto i : index) {
+ CHECK(IsTuple(*return_shape));
+ return_shape = return_shape->mutable_tuple_shapes(i);
+ }
+ return return_shape;
+}
+
+/* static */ Shape ShapeUtil::StripDegenerateDimensions(const Shape& shape) {
+ std::vector<int64> dimension_sizes;
+ std::vector<int64> degenerate_dimensions;
+ for (int64 i = 0; i < shape.dimensions_size(); ++i) {
+ if (shape.dimensions(i) == 1) {
+ degenerate_dimensions.push_back(i);
+ } else {
+ dimension_sizes.push_back(shape.dimensions(i));
+ }
+ }
+
+ // Construct minor_to_major of stripped shape. The order of the non-degenerate
+ // dimensions should be preserved from the original shape. First, create
+ // vector of the non-degenerate dimensions from the original minor_to_major
+ // array.
+ std::vector<int64> minor_to_major;
+ for (int64 i : shape.layout().minor_to_major()) {
+ if (std::find(degenerate_dimensions.begin(), degenerate_dimensions.end(),
+ i) == degenerate_dimensions.end()) {
+ minor_to_major.push_back(i);
+ }
+ }
+
+ // The dimensions in minor_to_major need to be renumbered to account for the
+  // degenerate dimensions which have been removed. Decrement each dimension
+  // number once for each degenerate dimension which has a smaller number.
+ for (int i = 0; i < minor_to_major.size(); ++i) {
+ int adjustment = 0;
+ for (int64 dim : degenerate_dimensions) {
+ if (minor_to_major[i] > dim) {
+ adjustment++;
+ }
+ }
+ minor_to_major[i] -= adjustment;
+ }
+
+ {
+ std::vector<int64> dims(minor_to_major.size());
+ std::iota(dims.begin(), dims.end(), 0);
+ DCHECK(minor_to_major.size() == dims.size() &&
+ std::is_permutation(minor_to_major.begin(), minor_to_major.end(),
+ dims.begin()));
+ }
+ Shape stripped_shape =
+ shape.has_layout() ? MakeShapeWithLayout(shape.element_type(),
+ dimension_sizes, minor_to_major)
+ : MakeShape(shape.element_type(), dimension_sizes);
+
+ VLOG(10) << "Original_shape: " << HumanStringWithLayout(shape);
+ VLOG(10) << "Stripped_shape: " << HumanStringWithLayout(stripped_shape);
+ return stripped_shape;
+}
+
+namespace {
+
+// Helper for ForEachSubshape which visits the subshapes of the given shape in
+// DFS pre-order starting with the index.
+Status ForEachSubshapeHelper(const Shape& shape,
+ const ShapeUtil::VisitorFunction func,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(shape, *index));
+ if (ShapeUtil::IsTuple(shape)) {
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachSubshapeHelper(
+ ShapeUtil::GetTupleElementShape(shape, i), func, index));
+ index->pop_back();
+ }
+ }
+ return Status::OK();
+}
+
+// Helper for ForEachMutableSubshape which visits the subshapes of the given
+// shape in DFS pre-order starting with the index.
+Status ForEachMutableSubshapeHelper(
+ Shape* shape, const ShapeUtil::MutatingVisitorFunction func,
+ ShapeIndex* index) {
+ TF_RETURN_IF_ERROR(func(shape, *index));
+ if (ShapeUtil::IsTuple(*shape)) {
+ for (int64 i = 0; i < ShapeUtil::TupleElementCount(*shape); ++i) {
+ index->push_back(i);
+ TF_RETURN_IF_ERROR(ForEachMutableSubshapeHelper(
+ shape->mutable_tuple_shapes(i), func, index));
+ index->pop_back();
+ }
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+/* static */ Status ShapeUtil::ForEachSubshape(const Shape& shape,
+ VisitorFunction func) {
+ ShapeIndex index;
+ return ForEachSubshapeHelper(shape, func, &index);
+}
+
+/* static */ Status ShapeUtil::ForEachMutableSubshape(
+ Shape* shape, MutatingVisitorFunction func) {
+ ShapeIndex index;
+ return ForEachMutableSubshapeHelper(shape, func, &index);
+}
+
+/* static */ Shape ShapeUtil::PermuteDimensions(
+ tensorflow::gtl::ArraySlice<int64> permutation, const Shape& shape) {
+ Shape new_shape = shape;
+ new_shape.clear_dimensions();
+ for (auto dim : Permute(permutation, shape.dimensions())) {
+ new_shape.add_dimensions(dim);
+ }
+ if (shape.has_layout()) {
+ new_shape.mutable_layout()->clear_minor_to_major();
+ for (auto index : Permute(permutation, shape.layout().minor_to_major())) {
+ new_shape.mutable_layout()->add_minor_to_major(index);
+ }
+ }
+ return new_shape;
+}
+
+/* static */ std::tuple<bool, std::vector<int64>, std::vector<int64>>
+ShapeUtil::InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
+ const Shape& shape_post) {
+ auto nil = std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
+
+ std::vector<int64> deleted_indices;
+ std::vector<int64> inserted_indices;
+  // Returns false if any input/output dimension between
+  // prior_unmodified_dim_pair and unmodified_dim_pair has size >1. Otherwise,
+  // returns true and appends the degenerate input/output dimensions in the gap
+  // to deleted_indices/inserted_indices respectively.
+ auto check_modified_dims = [&shape_pre, &shape_post, &deleted_indices,
+ &inserted_indices](
+ std::pair<int64, int64> prior_unmodified_dim_pair,
+ std::pair<int64, int64> unmodified_dim_pair) {
+ for (int64 modified_input_dim = prior_unmodified_dim_pair.first + 1;
+ modified_input_dim < unmodified_dim_pair.first; ++modified_input_dim) {
+ if (shape_pre.dimensions(modified_input_dim) > 1) {
+ return false;
+ }
+ deleted_indices.push_back(modified_input_dim);
+ }
+ for (int64 modified_output_dim = prior_unmodified_dim_pair.second + 1;
+ modified_output_dim < unmodified_dim_pair.second;
+ ++modified_output_dim) {
+ if (shape_post.dimensions(modified_output_dim) > 1) {
+ return false;
+ }
+ inserted_indices.push_back(modified_output_dim);
+ }
+ return true;
+ };
+
+ std::vector<std::pair<int64, int64>> unmodified_dims =
+ DimensionsUnmodifiedByReshape(shape_pre, shape_post);
+ // Returns nil if the reshape modifies any non-degenerate input/output
+ // dimension. DimensionsUnmodifiedByReshape gives us all unmodified
+ // dimensions, so we only need to check whether dimensions in the gaps (thus
+ // modified) have size >1.
+ for (size_t i = 0; i <= unmodified_dims.size(); ++i) {
+ // Check (modified) dimensions between unmodified_dims[i-1] and
+ // unmodified_dims[i].
+ auto prior_unmodified_dim_pair =
+ i > 0 ? unmodified_dims[i - 1] : std::make_pair(-1LL, -1LL);
+ auto unmodified_dim_pair =
+ i < unmodified_dims.size()
+ ? unmodified_dims[i]
+ : std::make_pair(ShapeUtil::Rank(shape_pre),
+ ShapeUtil::Rank(shape_post));
+ if (!check_modified_dims(prior_unmodified_dim_pair, unmodified_dim_pair)) {
+ return nil;
+ }
+ }
+
+ return std::make_tuple(true, deleted_indices, inserted_indices);
+}
+
+/* static */ std::vector<std::pair<int64, int64>>
+ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
+ const Shape& output_shape) {
+  // Returns an empty vector if the input or output shape has zero elements.
+  // This is safe but might be too conservative. Not a big deal for now because
+  // the IR emitted for zero-element shapes is often trivially optimizable
+  // without the help of this method.
+ if (ShapeUtil::ElementsIn(input_shape) == 0 ||
+ ShapeUtil::ElementsIn(output_shape) == 0) {
+ return std::vector<std::pair<int64, int64>>();
+ }
+
+ std::vector<std::pair<int64, int64>> unmodified_dims;
+ int64 input_dim = 0;
+ int64 output_dim = 0;
+
+ // A reshape preserves input_dim as output_dim iff
+ // 1. input_dim and output_dim have the same size.
+ // 2. The size of the input subarray from dimension 0 to input_dim-1 equals
+ // that of the output subarray from dimension 0 to output_dim-1.
+ VLOG(3) << "DimensionsUnmodifiedByReshape: input_shape="
+ << ShapeUtil::HumanString(input_shape)
+ << ", output_shape=" << ShapeUtil::HumanString(output_shape);
+ while (input_dim < ShapeUtil::Rank(input_shape) &&
+ output_dim < ShapeUtil::Rank(output_shape)) {
+ // partial_input_size is the product of sizes of input dimensions
+ // inclusively between the input_dim when this loop iteration starts and the
+ // current input_dim. partial_output_size is that of output dimensions. We
+ // compute these two values incrementally to save time.
+ int64 partial_input_size = input_shape.dimensions(input_dim);
+ int64 partial_output_size = output_shape.dimensions(output_dim);
+ // Move input_dim and output_dim forward until
+ // partial_input_size==partial_output_size.
+ while (partial_input_size != partial_output_size) {
+ if (partial_input_size < partial_output_size) {
+ ++input_dim;
+ partial_input_size *= input_shape.dimensions(input_dim);
+ } else {
+ ++output_dim;
+ partial_output_size *= output_shape.dimensions(output_dim);
+ }
+ }
+ CHECK_LT(input_dim, ShapeUtil::Rank(input_shape));
+ CHECK_LT(output_dim, ShapeUtil::Rank(output_shape));
+ if (input_shape.dimensions(input_dim) ==
+ output_shape.dimensions(output_dim)) {
+ unmodified_dims.push_back({input_dim, output_dim});
+ VLOG(3) << "Matching dimension pair: " << input_dim << ' ' << output_dim;
+ }
+ ++input_dim;
+ ++output_dim;
+ }
+
+ return unmodified_dims;
+}
+
+/* static */ bool ShapeUtil::TransposeIsBitcast(
+ const Shape& input_shape, const Shape& output_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping) {
+ // Can't insert bitcasts without layout information.
+ if (!LayoutUtil::HasLayout(input_shape) &&
+ !LayoutUtil::HasLayout(output_shape)) {
+ return false;
+ }
+
+ // Padding is not handled.
+ if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) {
+ return false;
+ }
+
+  // Check that the transpose permutes the positions of each dimension in the
+ // minor-to-major order. positions[i]=k means dimension `i` is k-th minor.
+ // input_positions = apply(dimension_mapping, output_positions)
+ //
+ // Because the positions of each dimension are the inverse permutation of the
+ // minor-to-major order, the above check is equivalent to
+ // inverse(input_dimensions) =
+ // apply(dimension_mapping, inverse(output_dimensions))
+ // # `I` indicates identity permutation.
+ // apply(input_dimensions, I) =
+ // apply(dimension_mapping, apply(output_dimensions, I))
+ // apply(input_dimensions, I) =
+ // apply((dimension_mapping * output_dimensions), I)
+ // input_dimensions = dimension_mapping * output_dimensions
+ return ContainersEqual(
+ ComposePermutations(dimension_mapping,
+ AsInt64Slice(output_shape.layout().minor_to_major())),
+ input_shape.layout().minor_to_major());
+}
+
+/* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape,
+ const Shape& output_shape) {
+ // Can't convert reshapes into bitcasts without layout information.
+ if (!LayoutUtil::HasLayout(input_shape) ||
+ !LayoutUtil::HasLayout(output_shape)) {
+ return false;
+ }
+
+ // Padding is not handled.
+ if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) {
+ return false;
+ }
+
+ CHECK_EQ(ShapeUtil::ElementsIn(input_shape),
+ ShapeUtil::ElementsIn(output_shape));
+ if (ShapeUtil::ElementsIn(input_shape) == 0) {
+ return true;
+ }
+
+ // TL;DR: The rest of the method checks that the reshape does not change the
+ // physical location of any unit input or output index. Unit indices have
+ // exactly one dimension that equals 1 and other dimensions 0. This condition
+ // is necessary for the reshape to be a bitcast, because a bitcast-equivalent
+ // reshape shouldn't change the physical location of any element. It is also a
+ // sufficient condition as is proved below (note: many details are omitted for
+ // space).
+ //
+ // Definitions:
+ //
+ // * Denote the input shape by IS and output shape by OS. IS[i] or OS[i] means
+ // the size of i-th least significant dimension of IS or OS (this is opposite
+ // to how we define the index of Shape::dimensions()).
+ //
+ // * Given an input or output index I, denote by p(I) I's physical linear
+ // index (or physical index for short) and l(I) I's logical linear index (or
+ // logical index for short).
+ //
+ // * Given a logical index k, denote by II(k) the input index whose linear
+ // index is k, and OI(k) the corresponding output index.
+ //
+ // * Denote by IT[i] the increment of physical index if i-th dimension of the
+ // input index is increased by 1. Similarly, OT[i] means the increment if i-th
+ // dimension of the output index is increased by 1. Note that IT[i] or OT[i]
+ // is a function of IS or OS and the layout, and not dependent on the specific
+ // input or output index.
+ //
+ // To prove the reshape from IS to OS is a bitcast, it is sufficient to prove
+ // that, for any linear index k, p(II(k))=p(OI(k)). We prove this by
+ // induction. We know p(II(0))=p(OI(0)) is trivially true, so what's left is
+ // to prove, with every increment on k, the above formula still holds.
+ //
+ // First, suppose reshaping from IS to OS is non-factorizable (we discuss
+  // factorizable reshapes later). A reshape from IS to OS is factorizable, if
+ // there exists (i,j) such that
+ //
+ // 0<=i<=|IS|
+ // 0<=j<=|OS|
+ // |IS|-i+|OS|-j > 0 (i.e., i,j mustn't both point to the end)
+ // product(IS[i], IS[i+1], ..., IS[|IS|-1])
+ // = product(OS[j], OS[j+1], ..., OS[|OS|-1])
+ //
+ // p(II(k))=p(OI(k)) is trivially true for k=0 because p(II(0)) and p(OI(0))
+ // are both 0. It's also trivially true for k=1, because II(1) and OI(1) are
+ // unit indices which are already tested. This also means IT[0]=OT[0]
+ // because p(II(1))=IT[0] and p(OI(1))=OT[0].
+ //
+ // Furthermore, p(II(k))=p(OI(k)) for k<min(IS[0],OS[0]), because each
+ // increment of k adds IT[0] to the input physical and OT[0] (same as IT[0])
+ // to the output physical.
+ //
+ // When k=min(IS[0],OS[0]), the first wrap happens. Without losing generality,
+ // suppose IS[0]<OS[0] and thus k=IS[0]. Similar proof applies to IS[0]>OS[0].
+ // Note that IS[0]!=OS[0] because the reshape is non-factorizable. From
+ // logical index k-1 to logical index k, dimension 1 of the input index
+ // is increased by 1 and dimension 0 is reset to 0 thus decreased by
+ // IS[0]-1. Therefore, the physical input index is increased by
+ //
+ // p(II(k)) - p(II(k-1)) = IT[1] - (IS[0]-1) * IT[0]
+ //
+ // Because IS[0]<OS[0], the only change to the output index is that its
+ // dimension 0 is increased by one. Therefore,
+ //
+ // p(OI(k)) - p(OI(k-1)) = OT[0] = IT[0]
+ //
+  // Because II(k) is a unit index -- (0,..,0,1,0), we already tested that
+ // p(II(k))=p(OI(k)). Therefore,
+ // IT[1] - (IS[0]-1) * IT[0] = IT[0]
+ // IT[1] = IS[0] * IT[0]
+ // In other words, input dimension 1 is immediately more major than input
+ // dimension 0. We can now conceptually collapse these two dimensions because
+ // an increment in the logical index affecting only these two dimensions maps
+ // to IT[0] in the physical index.
+ //
+ // By induction (omitted here), we can prove IT[i]=IS[i-1]*IT[i-1] and
+ // OT[i]=OS[i-1]*OT[i-1]. Therefore, both IS and OS are row-major and bitwise
+ // identical.
+ //
+ // A factorizable reshape can be factorized into a list of non-factorizable
+ // sub-reshapes, each of which can be handled similarly to the proof above.
+ // For example,
+ //
+ // [7x9x2x15] -> [63x6x5]
+ //
+ // can be factorized into
+ //
+ // [7x9] -> [63] and [2x15] -> [6x5].
+ //
+ // Suppose input index I=(x3,x2,x1,x0) and output index O=(y2,y1,y0) have the
+ // same logical linear index. According to the factorization, we know
+ // l(x3,x2,0,0)=l(y2,0,0) and l(0,0,x1,x0)=l(0,y1,y0). Using the proof for
+ // non-factorizable reshapes, we can prove p(0,0,x1,x0)=p(0,y1,y0). Using a
+ // similar proof, with the increment of the logical index set to
+ // IS[1]*IS[0]=OS[1]*OS[0]=30 instead of 1, we can prove
+ // p(x3,x2,0,0)=p(y2,0,0) too. Therefore,
+ //
+ // p(x3,x2,x1,x0) = p(x3,x2,0,0) + p(0,0,x1,x0)
+ // = p(y2,0,0) + p(0,0,y1,y0)
+ // = p(y2,y1,y0)
+ //
+ // check_input_unit_indices checks one way of the condition: each input unit
+ // index is mapped to an output index with the same physical location. This
+ // lambda will be called again with input_shape and output_shape reversed to
+ // check the other way.
+ auto check_input_unit_indices = [](const Shape& input_shape,
+ const Shape& output_shape) {
+ // input_shape_dim0_major/output_shape_dim0_major has the same "dimensions"
+ // as input_shape/output_shape and the dimension-0-major layout. These two
+ // shapes are used for conversion between logical linear indices and
+ // multi-dimensional indices.
+ Shape input_shape_dim0_major =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ input_shape.element_type(), AsInt64Slice(input_shape.dimensions()));
+ Shape output_shape_dim0_major =
+ ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
+ output_shape.element_type(),
+ AsInt64Slice(output_shape.dimensions()));
+
+ for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape);
+ ++input_dim) {
+ if (input_shape.dimensions(input_dim) <= 1) {
+ continue;
+ }
+
+ std::vector<int64> input_unit_index(ShapeUtil::Rank(input_shape), 0);
+ input_unit_index[input_dim] = 1;
+ int64 logical_linear_index =
+ IndexUtil::MultidimensionalIndexToLinearIndex(input_shape_dim0_major,
+ input_unit_index);
+ // output_index has the same logical linear index as input_unit_index.
+ std::vector<int64> output_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(output_shape_dim0_major,
+ logical_linear_index);
+ // Check input_unit_index and output_index have the same physical linear
+ // index.
+ if (IndexUtil::MultidimensionalIndexToLinearIndex(input_shape,
+ input_unit_index) !=
+ IndexUtil::MultidimensionalIndexToLinearIndex(output_shape,
+ output_index)) {
+ return false;
+ }
+ }
+ return true;
+ };
+ return check_input_unit_indices(input_shape, output_shape) &&
+ check_input_unit_indices(output_shape, input_shape);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
new file mode 100644
index 0000000000..35fd714b0b
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -0,0 +1,393 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Shapes are protobuf messages, so this utility header offers a bunch of
+// functionality for querying / poking at them.
+
+#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
+
+#include <functional>
+#include <initializer_list>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// An index for specifying a particular nested subshape within a shape. Used in
+// ShapeUtil::GetSubshape and other interfaces. Shapes are recursive data
+// structures (trees) and ShapeIndex defines a path through the tree where each
+// element of ShapeIndex indexes into a tuple (or nested tuple) within the
+// shape. For a non-nested tuple, an index has a single element. For example,
+// given a 3-element tuple (a, b, c) containing arrays a, b, and c, the index
+// {1} corresponds to array b. For a nested tuple, the index can have more than
+// one element. For the nested tuple (a, (b, c, d), e) below are the values
+// corresponding to the given indices:
+//
+// index {0} : array a
+// index {1, 2} : array d
+// index {2} : array e
+// index {0, 0} : invalid index (element at {0} is an array not a tuple)
+//
+// For indexing into array shapes, the index is always trivially empty,
+// i.e. {}.
+//
+// ShapeIndex is a trivial wrapper around std::vector with a minimum number of
+// methods implemented.
+class ShapeIndex {
+ public:
+ ShapeIndex() = default;
+ ShapeIndex(std::initializer_list<int64> init) : indices_(init) {}
+
+ bool empty() const { return indices_.empty(); }
+ size_t size() const { return indices_.size(); }
+ void push_back(int64 value) { indices_.push_back(value); }
+ void pop_back() { indices_.pop_back(); }
+
+ std::vector<int64>::const_iterator begin() const { return indices_.begin(); }
+ std::vector<int64>::const_iterator end() const { return indices_.end(); }
+ std::vector<int64>::iterator begin() { return indices_.begin(); }
+ std::vector<int64>::iterator end() { return indices_.end(); }
+
+ const int64& operator[](size_t i) const { return indices_[i]; }
+ int64& operator[](size_t i) { return indices_[i]; }
+
+ bool operator==(const ShapeIndex& other) const {
+ return indices_ == other.indices_;
+ }
+ bool operator!=(const ShapeIndex& other) const { return !(*this == other); }
+
+ private:
+ std::vector<int64> indices_;
+};
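+
+// A hypothetical usage sketch (illustration only): given a Shape `tuple_shape`
+// with the nested structure (a, (b, c, d), e) described above, the shape of
+// array d can be looked up with a two-element index:
+//
+//   const Shape& d_shape = ShapeUtil::GetSubshape(tuple_shape, {1, 2});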
+
+// Namespaced collection of (static) shape utilities.
+//
+// These are all effectively convenience functions for testing/tweaking proto
+// properties, which do invariant checks before / after the operation.
+class ShapeUtil {
+ public:
+  // Returns the number of elements contained within the provided shape;
+ // e.g. for rank 0 (scalars) the result is always 1.
+ static int64 ElementsIn(const Shape& shape);
+
+ // Returns true if 'shape' has zero elements.
+ static bool HasZeroElements(const Shape& shape);
+
+ // Returns the number of bytes required for an allocation of shape. The
+ // |pointer_size| parameter is used for calculating the size of tuple
+ // shapes. This includes only the size of the top-level buffer. For example, a
+ // tuple is stored as an array of pointers to other buffers. In this case,
+ // this method only returns the size of the pointer array.
+ static int64 ByteSizeOf(const Shape& shape, int64 pointer_size);
+
+ // Returns the number of bytes required for an allocation of shape.
+ // The calculation for tuple shapes assumes that we are utilizing host
+ // pointers.
+ // Precondition: !ShapeUtil::IsOpaque(shape)
+ static int64 ByteSizeOf(const Shape& shape);
+
+ // Returns the number of bytes used to store the primitive_type.
+ //
+  // Precondition: primitive_type != TUPLE && primitive_type != OPAQUE
+ static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type);
+
+ // Returns a human-readable string that represents the given shape, with or
+  // without layout. e.g. "f32[42,12] {0, 1}" or "f32[64]".
+ static string HumanString(const Shape& shape);
+ static string HumanStringWithLayout(const Shape& shape);
+
+  // As above, but for program shapes, returns a string of the form:
+  //
+  // (param_name: f32[42,12], ...) -> f32[24,42]
+ static string HumanString(const ProgramShape& shape);
+
+ // Parses a ShapeUtil::HumanString-format shape string back into a shape
+ // object.
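+  // For example (illustrative): "f32[123,456]", or "f32[123,456] {1,0}" when a
+  // layout is given.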
+ static StatusOr<Shape> ParseShapeString(const string& s);
+
+ // Returns whether the LHS and RHS shapes have the same dimensions; note: does
+ // not check element type.
+ static bool SameDimensions(const Shape& lhs, const Shape& rhs);
+
+ // Returns whether the lhs and rhs shapes have the same element type.
+ static bool SameElementType(const Shape& lhs, const Shape& rhs) {
+ return lhs.element_type() == rhs.element_type();
+ }
+
+ // Returns true if the rank, dimension sizes, and element type are
+ // identical. Layout is ignored. Tuple elements are compared recursively for
+ // compatibility.
+ static bool Compatible(const Shape& lhs, const Shape& rhs);
+
+ // Returns whether the lhs and rhs shapes are identical protobufs.
+ static bool Equal(const Shape& lhs, const Shape& rhs);
+
+ // Returns the rank (number of dimensions) of the given shape.
+ static int64 Rank(const Shape& shape) { return shape.dimensions_size(); }
+
+  // Returns the number of dimensions for which the dimension is not
+  // (trivially) 1. e.g., f32[2,1,1] has a true rank of 1; the other dimensions
+  // are just fluff. Note that zero-sized dimensions are included in the true
+  // rank, e.g., f32[3,0,1] has a true rank of 2.
+ static int64 TrueRank(const Shape& shape);
+
+ static ProgramShape MakeProgramShape(std::initializer_list<Shape> parameters,
+ Shape result);
+
+ ////////////////////
+ // Scalar-specific
+
+ static bool IsScalar(const Shape& shape) {
+ return !IsTuple(shape) && !IsOpaque(shape) && Rank(shape) == 0;
+ }
+ static bool IsEffectiveScalar(const Shape& shape) {
+ return !IsTuple(shape) && !IsOpaque(shape) && TrueRank(shape) == 0;
+ }
+ static bool IsScalarF32(const Shape& shape);
+
+ // Extracts the size of the shape's dimension at dimension number
+ // GetDimensionNumber(dimension_number).
+ static int64 GetDimension(const Shape& shape, int64 dimension_number);
+
+ // Resolves a dimension number, supporting negative indexing.
+ //
+ // Negative indexing has similar semantics to Python. For an N-dimensional
+ // array, dimension -1 is equivalent to dimension N-1, -2 is equivalent to
+ // N-2, and so on.
+ //
+ // This function always returns a positive dimension number for any given
+ // dimension_number (which itself can be negative).
+ static int64 GetDimensionNumber(const Shape& shape, int64 dimension_number);
+
+ // Returns a shape with the same dimensions as the original, but with the
+ // element type changed to type.
+ static Shape ChangeElementType(const Shape& original, PrimitiveType type);
+
+ // Creates a tuple shape from a slice of element shapes within the tuple.
+ static Shape MakeTupleShape(tensorflow::gtl::ArraySlice<Shape> shapes);
+
+ // Creates an opaque shape. These are generally used for threading a context
+ // into a custom operation.
+ static Shape MakeOpaqueShape();
+
+ // Appends a shape to the given tuple.
+ static void AppendShapeToTuple(const Shape& shape, Shape* tuple_shape);
+
+ // Appends a major dimension to the shape with the given bound.
+ static void AppendMajorDimension(int bound, Shape* shape);
+
+ // Returns an empty tuple shape. Can be used to indicate side-effects.
+ static Shape MakeNil() { return MakeTupleShape({}); }
+
+ // Constructs a new shape with the given element type and sequence of
+ // dimensions.
+ static Shape MakeShape(PrimitiveType element_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+ // Constructs a new shape with the given minor_to_major order in its Layout.
+ // Returns a value shape such that shape.has_layout().
+ static Shape MakeShapeWithLayout(
+ PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major);
+
+ // Constructs a new shape with major-first layout.
+ static Shape MakeShapeWithMonotonicDim0MajorLayout(
+ PrimitiveType element_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions);
+
+  // Returns a new shape with a monotonic dim0-major (major-first) layout whose
+  // dimensions are the given shape's dimensions reordered from most-major to
+  // most-minor. The physical ordering of elements is unchanged.
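+  // For example (illustrative): f32[10,20] with layout {0,1} (dimension 0 most
+  // minor) normalizes to f32[20,10] with layout {1,0}.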
+ static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape);
+
+ // As MakeShape, but the object to write to is passed in.
+ static void PopulateShape(PrimitiveType element_type,
+ tensorflow::gtl::ArraySlice<int64> dimensions,
+ Shape* shape);
+
+ // Validates that the provided shape satisfies invariants.
+ static Status ValidateShape(const Shape& shape);
+
+  // Validates that the provided shape satisfies invariants, except those that
+ // pertain to layout.
+ //
+ // Layout is optional for client-provided shapes, so that the compiler may
+ // determine and assign an optimized layout.
+ static Status ValidateShapeWithOptionalLayout(const Shape& shape);
+
+ // Returns whether the element type of the shape is integral (signed or
+ // unsigned). Note that predicates are not considered integral here, since
+ // they are logical values.
+ static bool ElementIsIntegral(const Shape& shape);
+
+ // Returns whether the element type of the shape is floating point.
+ static bool ElementIsFloating(const Shape& shape);
+
+ // Returns whether the element type has the given bit width.
+ static bool ElementHasBitWidth(const Shape& shape, int bits);
+
+ // Returns whether the element type of the shape is integral and has
+ // the specified number of bits.
+ static bool ElementIsIntegralWithBits(const Shape& shape, int bits);
+
+ // Returns whether the element type of the shape is signed. Note
+ // that floating point numbers are signed.
+ static bool ElementIsSigned(const Shape& shape);
+
+ // Returns whether the shape is a tuple.
+ static bool IsTuple(const Shape& shape);
+
+ // Returns whether the shape is an array.
+ static bool IsArray(const Shape& shape);
+
+ // Returns whether the shape is an opaque.
+ static bool IsOpaque(const Shape& shape);
+
+ // Returns whether the shape is a tuple with at least one element which is
+ // also a tuple.
+ static bool IsNestedTuple(const Shape& shape);
+
+ // Returns true if shape is an empty tuple.
+ static bool IsEmptyTuple(const Shape& shape);
+
+ // Returns true if shape is an empty tuple, or is an array with no elements.
+ static bool IsNil(const Shape& shape);
+
+ // Returns the number of elements in the given tuple shape.
+ // Precondition: IsTuple(shape)
+ static int64 TupleElementCount(const Shape& shape);
+
+ // Returns the tuple element shape at given index.
+ // Precondition: IsTuple(shape) && TupleElementCount(shape) > index
+ static const Shape& GetTupleElementShape(const Shape& shape, int64 index);
+
+ // Slices tuple elements in the range [start, limit) and returns a new tuple
+ // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32).
+ static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit);
+
+ // Shorthand for testing whether a shape is of a given element type and
+ // sequence of dimensions.
+ static bool ShapeIs(const Shape& shape, PrimitiveType element_type,
+ std::initializer_list<int64> dimensions);
+
+ // GetSubshape and GetMutableSubshape return a particular nested Shape within
+ // the given Shape argument.
+ static const Shape& GetSubshape(const Shape& shape, const ShapeIndex& index);
+ static Shape* GetMutableSubshape(Shape* shape, const ShapeIndex& index);
+
+ // Calls the given visitor function for each subshape of the given shape.
+ // Returns early if an error status is returned. Subshapes are visited in DFS
+ // pre-order starting with the entire shape (index {}).
+ using VisitorFunction = std::function<Status(const Shape& /*subshape*/,
+ const ShapeIndex& /*index*/)>;
+ static Status ForEachSubshape(const Shape& shape, VisitorFunction func);
+
+ // Mutating variant of ForEachSubshape.
+ using MutatingVisitorFunction =
+ std::function<Status(Shape* /*subshape*/, const ShapeIndex& /*index*/)>;
+ static Status ForEachMutableSubshape(Shape* shape,
+ MutatingVisitorFunction func);
+
+ // Removes all degenerate dimensions (size one) from the given shape. The
+ // stripped minor_to_major preserves the relative ordering of non-degenerate
+ // dimensions. The stripped shape has the property that the underlying
+ // representation (bits in memory) for the stripped shape is the same as the
+ // original shape modulo padding. Examples:
+ //
+ // input shape: F32 [1, 2, 1], minor_to_major = {0, 1, 2}
+ // stripped shape: F32 [2], minor_to_major = {0}
+ //
+ // input shape: F32 [6, 1, 5], minor_to_major = {2, 0, 1}
+ // stripped shape: F32 [6, 5], minor_to_major = {1, 0}
+ //
+ // input shape: F32 [1, 7, 1, 6, 5, 1], minor_to_major = {0, 2, 5, 4, 3, 1}
+ // stripped shape: F32 [7, 6, 5], minor_to_major = {0, 2, 1}
+ //
+ // input shape: F32 [1, 1], minor_to_major = {0, 1}
+ // stripped shape: F32 [], minor_to_major = {}
+ // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape)
+ static Shape StripDegenerateDimensions(const Shape& shape);
+
+ // Permutes the dimensions by the given permutation, so
+ // return_value.dimensions[permutation[i]] = argument.dimensions[i]
+ static Shape PermuteDimensions(tensorflow::gtl::ArraySlice<int64> permutation,
+ const Shape& shape);
+
+ // If we can go from `shape_pre` to `shape_post` by merely inserting or
+ // deleting 1-sized dimensions, return the indices in `shape_pre` of the
+  // deleted dimensions and the indices in `shape_post` of the inserted
+ // dimensions.
+ // For example, if `shape_pre = {a_1, a_2, ..., a_m}` and
+ // `shape_post = {b_1, b_2, ..., b_n}` where we can find some sequence of `i`s
+ // and some sequence of `j`s so `a_i = 1` for each `i` and `b_j = 1` for each
+ // `j` and `a_(k-s) = b_(k-t)` where `s` and `t` are the number of `i`s and
+ // `j`s less than `k` for all other `k`, we return the `i`s and `j`s.
+ // For another example, if `shape_pre = shape_post = {}`, we return `{}`.
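+  // A worked illustration (values chosen here, not from the original comment):
+  // for shape_pre = {9, 1, 4} and shape_post = {9, 4, 1} the result is
+  // {true, {1}, {2}}: dimension 1 of `shape_pre` was deleted and dimension 2
+  // of `shape_post` was inserted.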
+ static std::tuple<bool, std::vector<int64>, std::vector<int64>>
+ InsertedOrDeleted1SizedDimensions(const Shape& shape_pre,
+ const Shape& shape_post);
+
+  // Suppose a reshape transforms input_shape to output_shape. Returns a vector
+ // of pairs that indicate the input and output dimensions that this reshape
+ // doesn't logically (i.e. ignoring the layout) modify. For each pair (I,O) in
+ // the returned vector, the reshape transforms any input index whose I-th
+ // dimension is x to an output index whose O-th dimension is x too.
+ //
+ // Post-condition: the returned vector is sorted (by both input and output
+ // dimensions because input and output dimensions have the same order).
+ //
+ // Example:
+ // input shape = T[a, b, x, y, cd]
+ // output shape = T[ab, x, 1, y, c, d]
+ // return value = {{2, 1}, {3, 3}}
+ //
+ // The two pairs represent the input and output dimension of size x and
+ // those of size y.
+ static std::vector<std::pair<int64, int64>> DimensionsUnmodifiedByReshape(
+ const Shape& input_shape, const Shape& output_shape);
+
+ // Returns whether a transpose from input_shape to output_shape with dimension
+ // mapping "dimension_mapping" produces a result which is bit-wise identical
+ // to its input and thus may be replaced with a bitcast.
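+ // For instance (illustrative): transposing f32[3,4] with minor_to_major
+ // {1,0} into f32[4,3] with minor_to_major {0,1} using dimension_mapping
+ // {1,0} leaves the underlying bytes unchanged, so that transpose is a
+ // bitcast.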
+ static bool TransposeIsBitcast(
+ const Shape& input_shape, const Shape& output_shape,
+ tensorflow::gtl::ArraySlice<int64> dimension_mapping);
+
+ // Returns whether a reshape from "input_shape" to "output_shape" is a
+ // bitcast.
+ static bool ReshapeIsBitcast(const Shape& input_shape,
+ const Shape& output_shape);
+
+ private:
+ // Recursive helper for comparing the equality of two shapes. Returns true if
+ // the shapes are the same. If compare_layouts is true, then layouts must also
+ // match.
+ static bool CompareShapes(const Shape& lhs, const Shape& rhs,
+ bool compare_layouts);
+
+ // Validates all of the non-layout properties of the shape -- this is a helper
+ // used by both the layout-optional and layout-required public methods.
+ static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc
new file mode 100644
index 0000000000..4e8a496e7e
--- /dev/null
+++ b/tensorflow/compiler/xla/shape_util_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/shape_util.h"
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(ShapeUtilTest, GetDimensionHelperCanNegativeIndex) {
+ Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
+ EXPECT_EQ(3, ShapeUtil::GetDimension(matrix, -1));
+ EXPECT_EQ(2, ShapeUtil::GetDimension(matrix, -2));
+}
+
+TEST(ShapeUtilTest, GetDimensionHelperExampleInDocumentationTest) {
+ auto shape = ShapeUtil::MakeShape(F32, {1, 2, 3, 4});
+ ASSERT_EQ(4, ShapeUtil::GetDimension(shape, -1));
+}
+
+TEST(ShapeUtilTest, NegativeIndexOobFails) {
+ Shape matrix = ShapeUtil::MakeShape(F32, {2, 3});
+ ASSERT_DEATH(ShapeUtil::GetDimension(matrix, -3), "dimension_number >= 0");
+}
+
+TEST(ShapeUtilTest, Rank1DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank2DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2});
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank3DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7});
+ ASSERT_EQ(7, shape.dimensions(2));
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, Rank4DimensionIndexing) {
+ Shape shape = ShapeUtil::MakeShape(F32, {3, 2, 7, 8});
+ ASSERT_EQ(8, shape.dimensions(3));
+ ASSERT_EQ(7, shape.dimensions(2));
+ ASSERT_EQ(2, shape.dimensions(1));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+TEST(ShapeUtilTest, ParseShapeStringR2F32) {
+ string shape_string = "f32[123,456]";
+ Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie();
+ Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
+ ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
+ << "expected: " << ShapeUtil::HumanString(expected)
+ << "actual: " << ShapeUtil::HumanString(actual);
+}
+
+TEST(ShapeUtilTest, CompatibleIdenticalShapes) {
+ Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2});
+ Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
+ ASSERT_TRUE(ShapeUtil::Compatible(shape1, shape2));
+}
+
+TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
+ Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
+ auto layout_1 = shape_1.mutable_layout();
+ layout_1->clear_minor_to_major();
+ layout_1->add_minor_to_major(0);
+ layout_1->add_minor_to_major(1);
+
+ Shape shape_2 = ShapeUtil::MakeShape(F32, {3, 2});
+ auto layout_2 = shape_2.mutable_layout();
+ layout_2->clear_minor_to_major();
+ layout_2->add_minor_to_major(1);
+ layout_2->add_minor_to_major(0);
+
+ EXPECT_FALSE(ShapeUtil::Equal(shape_1, shape_2));
+ EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
+}
+
+TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
+ Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
+ Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
+ EXPECT_FALSE(ShapeUtil::Compatible(shape_1, shape_2));
+}
+
+TEST(ShapeUtilTest, CompatibleTuples) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(PRED, {4, 5})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(S32, {3, 2})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) {
+ Shape tuple1 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
+ Shape tuple2 = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {4, 2})});
+ EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2));
+}
+
+TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) {
+ // A shape with a missing layout should be equal to a shape with an empty
+ // layout.
+ Shape scalar1 = ShapeUtil::MakeShape(F32, {});
+ Shape scalar2 = ShapeUtil::MakeShape(F32, {});
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2));
+
+ scalar1.clear_layout(); // Remove layout field.
+ scalar2.mutable_layout(); // Create empty layout field.
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2));
+}
+
+TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) {
+ Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {});
+ scalar_unpopulated.clear_layout();
+ ASSERT_FALSE(scalar_unpopulated.has_layout())
+ << ShapeUtil::HumanStringWithLayout(scalar_unpopulated);
+
+ const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {});
+ ASSERT_TRUE(scalar_populated.has_layout())
+ << ShapeUtil::HumanStringWithLayout(scalar_populated);
+
+ EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated));
+}
+
+TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) {
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {})));
+ EXPECT_EQ(800, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F32, {10, 20})));
+
+ EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64));
+ EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {})));
+ EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20})));
+}
+
+TEST(ShapeUtilTest, ByteSizeOfWithPadding) {
+ EXPECT_EQ(4, ShapeUtil::ByteSizeOfPrimitiveType(F32));
+ Shape shape = ShapeUtil::MakeShape(F32, {10, 20});
+ EXPECT_EQ(800, ShapeUtil::ByteSizeOf(shape));
+
+ shape.mutable_layout()->add_padded_dimensions(15);
+ shape.mutable_layout()->add_padded_dimensions(21);
+ EXPECT_EQ(15 * 21 * 4, ShapeUtil::ByteSizeOf(shape));
+}
+
+TEST(ShapeUtilTest, NestedTuple) {
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape({})));
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeTupleShape({})})));
+ EXPECT_FALSE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeTupleShape({})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeShape(S32, {})})));
+ EXPECT_TRUE(ShapeUtil::IsNestedTuple(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({}), ShapeUtil::MakeTupleShape({})})));
+}
+
+TEST(ShapeUtilTest, ElementsIn) {
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_EQ(1, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_EQ(2, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_EQ(0, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_EQ(15, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_EQ(221, ShapeUtil::ElementsIn(ShapeUtil::MakeShape(S32, {13, 17})));
+}
+
+TEST(ShapeUtilTest, HasZeroElements) {
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {2, 1})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 5})));
+ EXPECT_TRUE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {3, 0, 5})));
+ EXPECT_TRUE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {0, 3, 0})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {1, 3, 5})));
+ EXPECT_FALSE(ShapeUtil::HasZeroElements(ShapeUtil::MakeShape(S32, {13, 17})));
+}
+
+TEST(ShapeUtilTest, SameDimensions) {
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(S32, {})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(S32, {1})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0}),
+ ShapeUtil::MakeShape(S32, {0})));
+ EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}),
+ ShapeUtil::MakeShape(S32, {2})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {2})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {0, 0}),
+ ShapeUtil::MakeShape(F32, {0})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 1})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1}),
+ ShapeUtil::MakeShape(F32, {1, 0})));
+ EXPECT_FALSE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {1, 1}),
+ ShapeUtil::MakeShape(F32, {1, 2})));
+}
+
+TEST(ShapeUtilTest, GetSubshape) {
+ // Test array shape.
+ Shape array_shape = ShapeUtil::MakeShape(F32, {42, 42, 123});
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(array_shape, {})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ array_shape, *ShapeUtil::GetMutableSubshape(&array_shape, {})));
+
+ // Test tuple shape.
+ Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({array_shape, array_shape, array_shape});
+ EXPECT_TRUE(
+ ShapeUtil::Equal(tuple_shape, ShapeUtil::GetSubshape(tuple_shape, {})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {0})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {1})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(array_shape, ShapeUtil::GetSubshape(tuple_shape, {2})));
+
+ // Test nested tuple shape.
+ Shape nested_tuple_shape = ShapeUtil::MakeTupleShape(
+ {array_shape, ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ array_shape})});
+ EXPECT_TRUE(ShapeUtil::Equal(nested_tuple_shape,
+ ShapeUtil::GetSubshape(nested_tuple_shape, {})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ array_shape, ShapeUtil::GetSubshape(nested_tuple_shape, {0})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::GetSubshape(nested_tuple_shape, {1})));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeTupleShape({array_shape, array_shape}),
+ ShapeUtil::GetSubshape(nested_tuple_shape, {2, 0})));
+}
+
+TEST(ShapeUtilTest, HumanString) {
+ Shape opaque = ShapeUtil::MakeOpaqueShape();
+ Shape scalar = ShapeUtil::MakeShape(F32, {});
+ Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
+ Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
+ Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
+ Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix});
+
+ EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
+ EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
+ EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
+ EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
+ EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
+ ShapeUtil::HumanString(tuple));
+ EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(nested_tuple));
+
+ EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
+ EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar));
+ EXPECT_EQ("u32[1,2] {1,0}", ShapeUtil::HumanStringWithLayout(matrix));
+ EXPECT_EQ("s32[3,4] {0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
+ EXPECT_EQ("(opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1})",
+ ShapeUtil::HumanStringWithLayout(tuple));
+ EXPECT_EQ(
+ "((opaque[], f32[], u32[1,2] {1,0}, s32[3,4] {0,1}), u32[1,2] {1,0})",
+ ShapeUtil::HumanStringWithLayout(nested_tuple));
+
+ ProgramShape prog = ShapeUtil::MakeProgramShape(
+ {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
+ EXPECT_EQ(
+ "((unknown): opaque[], "
+ "(unknown): f32[], "
+ "(unknown): u32[1,2], "
+ "(unknown): s32[3,4], "
+ "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(prog));
+
+ prog.add_parameter_names("arg0");
+ prog.add_parameter_names("scalar");
+ prog.add_parameter_names("matrix");
+ prog.add_parameter_names("matrix2");
+ prog.add_parameter_names("tuple");
+ prog.add_parameter_names("nested_tuple");
+ EXPECT_EQ(
+ "(arg0: opaque[], "
+ "scalar: f32[], "
+ "matrix: u32[1,2], "
+ "matrix2: s32[3,4], "
+ "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
+ "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])) -> "
+ "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2])",
+ ShapeUtil::HumanString(prog));
+}
+
+TEST(ShapeUtilTest, ForEachSubshapeArray) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
+ EXPECT_EQ(&shape, &subshape);
+ EXPECT_TRUE(index.empty());
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(1, calls);
+}
+
+TEST(ShapeUtilTest, ForEachSubshapeNestedTuple) {
+ const Shape shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {42}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
+ ShapeUtil::MakeShape(PRED, {33})})});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachSubshape(
+ shape, [&calls, &shape](const Shape& subshape, const ShapeIndex& index) {
+ EXPECT_TRUE(
+ ShapeUtil::Equal(subshape, ShapeUtil::GetSubshape(shape, index)));
+ if (calls == 0) {
+ // Visitation should go from outside in.
+ EXPECT_TRUE(index.empty());
+ } else if (calls == 4) {
+ // Last visitation should be to the array with 33 elements.
+ EXPECT_EQ(33, ShapeUtil::ElementsIn(subshape));
+ }
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(5, calls);
+}
+
+TEST(ShapeUtilTest, ForEachMutableSubshapeNestedTuple) {
+ Shape shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(F32, {42}),
+ ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {101}),
+ ShapeUtil::MakeShape(PRED, {33})})});
+ int calls = 0;
+ EXPECT_IS_OK(ShapeUtil::ForEachMutableSubshape(
+ &shape, [&calls, &shape](const Shape* subshape, const ShapeIndex& index) {
+ // Pointer values should be equal
+ EXPECT_EQ(subshape, ShapeUtil::GetMutableSubshape(&shape, index));
+ if (calls == 0) {
+ // Visitation should go from outside in.
+ EXPECT_TRUE(index.empty());
+ } else if (calls == 4) {
+ // Last visitation should be to the array with 33 elements.
+ EXPECT_EQ(33, ShapeUtil::ElementsIn(*subshape));
+ }
+ ++calls;
+ return tensorflow::Status::OK();
+ }));
+ EXPECT_EQ(5, calls);
+}
+
+TEST(ShapeUtilTest, InsertedOrDeleted1SizedDimensions) {
+ Shape shape0 = ShapeUtil::MakeShape(S32, {9, 1, 4});
+ Shape shape1 = ShapeUtil::MakeShape(S32, {1, 9, 4, 1});
+ Shape shape2 = ShapeUtil::MakeShape(S32, {3, 1, 12});
+ EXPECT_TRUE(std::get<0>(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape1)));
+ EXPECT_FALSE(std::get<0>(
+ ShapeUtil::InsertedOrDeleted1SizedDimensions(shape0, shape2)));
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
+ // All output dimensions should be unmodified. One of the input dimensions is
+ // modified because the input rank is larger by one.
+ EXPECT_EQ(3,
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1, 1}),
+ ShapeUtil::MakeShape(S32, {1, 1, 1}))
+ .size());
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1_to_1x1x1x1) {
+ // All input dimensions should be unmodified. One of the output dimensions is
+ // modified because the output rank is larger by one.
+ EXPECT_EQ(3,
+ ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1}),
+ ShapeUtil::MakeShape(S32, {1, 1, 1, 1}))
+ .size());
+}
+
+TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_4x1x3x5x6x7_to_2x6x1x5x1x42) {
+ // The only matching dimension is the one with size 5.
+ // 4, 1, 3, 5, 6, 7
+ // |
+ // 2, 6, 1, 5, 1, 42
+ EXPECT_TRUE(
+ ContainersEqual(ShapeUtil::DimensionsUnmodifiedByReshape(
+ ShapeUtil::MakeShape(S32, {4, 1, 3, 5, 6, 7}),
+ ShapeUtil::MakeShape(S32, {2, 6, 1, 5, 1, 42})),
+ std::vector<std::pair<int64, int64>>({{3, 3}})));
+}
+
+TEST(ShapeUtilTest, ReshapeIsBitcast_3x4_6x2) {
+ for (bool input_is_row_major : {true, false}) {
+ for (bool output_is_row_major : {true, false}) {
+ Layout input_layout = input_is_row_major ? LayoutUtil::MakeLayout({1, 0})
+ : LayoutUtil::MakeLayout({0, 1});
+ Layout output_layout = output_is_row_major
+ ? LayoutUtil::MakeLayout({1, 0})
+ : LayoutUtil::MakeLayout({0, 1});
+ // Suppose the input is logically (i.e. ignoring its layout)
+ // 0 1 2 3
+ // 4 5 6 7
+ // 8 9 10 11
+ //
+ // The reshape transforms the input to logically
+ // 0 1
+ // 2 3
+ // 4 5
+ // 6 7
+ // 8 9
+ // 10 11
+ //
+ // The input and the output have the same underlying data only if they
+ // are both row-major.
+ EXPECT_EQ(
+ ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(
+ F32, {3, 4}, AsInt64Slice(input_layout.minor_to_major())),
+ ShapeUtil::MakeShapeWithLayout(
+ F32, {6, 2}, AsInt64Slice(output_layout.minor_to_major()))),
+ input_is_row_major && output_is_row_major);
+ }
+ }
+}
+
+TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim1IsMostMinor) {
+ EXPECT_TRUE(ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {1, 0, 2}),
+ ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
+}
+
+TEST(ShapeUtilTest, ReshapeIsBitcast_3x2x2_6x2_Dim0IsMostMinor) {
+ EXPECT_FALSE(ShapeUtil::ReshapeIsBitcast(
+ ShapeUtil::MakeShapeWithLayout(F32, {3, 2, 2}, {0, 1, 2}),
+ ShapeUtil::MakeShapeWithLayout(F32, {6, 2}, {0, 1})));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/status.h b/tensorflow/compiler/xla/status.h
new file mode 100644
index 0000000000..f3b561fada
--- /dev/null
+++ b/tensorflow/compiler/xla/status.h
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_STATUS_H_
+#define TENSORFLOW_COMPILER_XLA_STATUS_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+#if defined(__clang__)
+// Only clang supports warn_unused_result as a type annotation.
+class TF_MUST_USE_RESULT Status;
+#endif
+
+// Simple wrapper around tensorflow::Status that has the MUST_USE_RESULT
+// annotation above. When tensorflow::Status adopts this annotation, this can
+// simply become a "using tensorflow::Status".
+class Status : public tensorflow::Status {
+ public:
+ static Status OK() { return tensorflow::Status::OK(); }
+
+ // Note: implicit constructor.
+ Status(tensorflow::Status other) : tensorflow::Status(other) {}
+
+ Status() : tensorflow::Status() {}
+ Status(tensorflow::error::Code code, tensorflow::StringPiece msg)
+ : tensorflow::Status(code, msg) {}
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_STATUS_H_
diff --git a/tensorflow/compiler/xla/status_macros.cc b/tensorflow/compiler/xla/status_macros.cc
new file mode 100644
index 0000000000..a6b1f9004f
--- /dev/null
+++ b/tensorflow/compiler/xla/status_macros.cc
@@ -0,0 +1,170 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include <algorithm>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/stacktrace.h"
+
+namespace xla {
+namespace status_macros {
+
+static Status MakeStatus(tensorflow::error::Code code, const string& message) {
+ return Status(code, message);
+}
+
+// Log the error at the given severity, optionally with a stack trace.
+// If log_severity is NUM_SEVERITIES, nothing is logged.
+static void LogError(const Status& status, const char* filename, int line,
+ int log_severity, bool should_log_stack_trace) {
+ if (TF_PREDICT_TRUE(log_severity != tensorflow::NUM_SEVERITIES)) {
+ string stack_trace;
+ if (should_log_stack_trace) {
+ stack_trace =
+ tensorflow::strings::StrCat("\n", tensorflow::CurrentStackTrace());
+ }
+ switch (log_severity) {
+ case tensorflow::INFO:
+ LOG(INFO) << status << stack_trace;
+ break;
+ case tensorflow::WARNING:
+ LOG(WARNING) << status << stack_trace;
+ break;
+ case tensorflow::ERROR:
+ LOG(ERROR) << status << stack_trace;
+ break;
+ case tensorflow::FATAL:
+ LOG(FATAL) << status << stack_trace;
+ break;
+ case tensorflow::NUM_SEVERITIES:
+ break;
+ default:
+ LOG(FATAL) << "Unknown LOG severity " << log_severity;
+ }
+ }
+}
+
+// Make a Status with a code, error message and payload,
+// and also send it to LOG(<log_severity>) using the given filename
+// and line (unless should_log is false, or log_severity is
+// NUM_SEVERITIES). If should_log_stack_trace is true, the stack
+// trace is included in the log message (ignored if should_log is
+// false).
+static Status MakeError(const char* filename, int line,
+ tensorflow::error::Code code, const string& message,
+ bool should_log, int log_severity,
+ bool should_log_stack_trace) {
+ if (TF_PREDICT_FALSE(code == tensorflow::error::OK)) {
+ LOG(ERROR) << "Cannot create error with status OK";
+ code = tensorflow::error::UNKNOWN;
+ }
+ const Status status = MakeStatus(code, message);
+ if (TF_PREDICT_TRUE(should_log)) {
+ LogError(status, filename, line, log_severity, should_log_stack_trace);
+ }
+ return status;
+}
+
+// This method is written out-of-line rather than in the header to avoid
+// generating a lot of inline code for error cases in all callers.
+void MakeErrorStream::CheckNotDone() const { impl_->CheckNotDone(); }
+
+MakeErrorStream::Impl::Impl(const char* file, int line,
+ tensorflow::error::Code code,
+ MakeErrorStream* error_stream,
+ bool is_logged_by_default)
+ : file_(file),
+ line_(line),
+ code_(code),
+ is_done_(false),
+ should_log_(is_logged_by_default),
+ log_severity_(tensorflow::ERROR),
+ should_log_stack_trace_(false),
+ make_error_stream_with_output_wrapper_(error_stream) {}
+
+MakeErrorStream::Impl::Impl(const Status& status,
+ PriorMessageHandling prior_message_handling,
+ const char* file, int line,
+ MakeErrorStream* error_stream)
+ : file_(file),
+ line_(line),
+ // Make sure we show some error, even if the call is incorrect.
+ code_(!status.ok() ? status.code() : tensorflow::error::UNKNOWN),
+ prior_message_handling_(prior_message_handling),
+ prior_message_(status.error_message()),
+ is_done_(false),
+ // Error code type is not visible here, so we can't call
+ // IsLoggedByDefault.
+ should_log_(true),
+ log_severity_(tensorflow::ERROR),
+ should_log_stack_trace_(false),
+ make_error_stream_with_output_wrapper_(error_stream) {
+ DCHECK(!status.ok()) << "Attempted to append/prepend error text to status OK";
+}
+
+MakeErrorStream::Impl::~Impl() {
+ // Note: error messages refer to the public MakeErrorStream class.
+
+ if (!is_done_) {
+ LOG(ERROR) << "MakeErrorStream destructed without getting Status: " << file_
+ << ":" << line_ << " " << stream_.str();
+ }
+}
+
+Status MakeErrorStream::Impl::GetStatus() {
+ // Note: error messages refer to the public MakeErrorStream class.
+
+ // Getting a Status object out more than once is not harmful, but
+ // it doesn't match the expected pattern, where the stream is constructed
+ // as a temporary, loaded with a message, and then cast to Status.
+ if (is_done_) {
+ LOG(ERROR) << "MakeErrorStream got Status more than once: " << file_ << ":"
+ << line_ << " " << stream_.str();
+ }
+
+ is_done_ = true;
+
+ const string& stream_str = stream_.str();
+ const string str =
+ prior_message_handling_ == kAppendToPriorMessage
+ ? tensorflow::strings::StrCat(prior_message_, stream_str)
+ : tensorflow::strings::StrCat(stream_str, prior_message_);
+ if (TF_PREDICT_FALSE(str.empty())) {
+ return MakeError(file_, line_, code_,
+ tensorflow::strings::StrCat(
+ str, "Error without message at ", file_, ":", line_),
+ true /* should_log */,
+ tensorflow::ERROR /* log_severity */,
+ should_log_stack_trace_);
+ } else {
+ return MakeError(file_, line_, code_, str, should_log_, log_severity_,
+ should_log_stack_trace_);
+ }
+}
+
+void MakeErrorStream::Impl::CheckNotDone() const {
+ if (is_done_) {
+ LOG(ERROR) << "MakeErrorStream shift called after getting Status: " << file_
+ << ":" << line_ << " " << stream_.str();
+ }
+}
+
+} // namespace status_macros
+} // namespace xla
diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h
new file mode 100644
index 0000000000..aa12cda666
--- /dev/null
+++ b/tensorflow/compiler/xla/status_macros.h
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
+#define TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
+
+#include <memory>
+#include <ostream> // NOLINT
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+namespace status_macros {
+
+// Stream object used to collect error messages in MAKE_ERROR macros
+// or append error messages with APPEND_ERROR. It accepts any
+// arguments with operator<< to build an error string, and then has an
+// implicit cast operator to Status, which converts the
+// logged string to a Status object and returns it, after logging the
+// error. At least one call to operator<< is required; a compile time
+// error will be generated if none are given. Errors will only be
+// logged by default for certain status codes, as defined in
+// IsLoggedByDefault. This class logs at ERROR severity if you don't
+// retrieve a Status exactly once before destruction.
+//
+// The class converts into an intermediate wrapper object
+// MakeErrorStreamWithOutput to check that the error stream gets at least one
+// item of input.
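+//
+// A hypothetical direct use (in practice the stream is usually built by a
+// macro such as TF_RET_CHECK below), assuming the enclosing function returns
+// Status:
+//
+// return xla::status_macros::MakeErrorStream(__FILE__, __LINE__,
+// tensorflow::error::INTERNAL)
+// << "something went wrong";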
+class MakeErrorStream {
+ public:
+ // Wrapper around MakeErrorStream that only allows for output. This
+ // is created as output of the first operator<< call on
+ // MakeErrorStream. The bare MakeErrorStream does not have a
+ // Status operator. The net effect of that is that you
+ // have to call operator<< at least once or else you'll get a
+ // compile time error.
+ class MakeErrorStreamWithOutput {
+ public:
+ explicit MakeErrorStreamWithOutput(MakeErrorStream* error_stream)
+ : wrapped_error_stream_(error_stream) {}
+
+ template <typename T>
+ MakeErrorStreamWithOutput& operator<<(const T& value) {
+ *wrapped_error_stream_ << value;
+ return *this;
+ }
+
+ // Implicit cast operators to Status and StatusOr.
+ // Exactly one of these must be called exactly once before destruction.
+ operator Status() { return wrapped_error_stream_->GetStatus(); }
+ template <typename T>
+ operator xla::StatusOr<T>() {
+ return wrapped_error_stream_->GetStatus();
+ }
+
+ private:
+ MakeErrorStream* wrapped_error_stream_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MakeErrorStreamWithOutput);
+ };
+
+ // When starting from an existing error status, this determines whether we'll
+ // append or prepend to that status's error message.
+ enum PriorMessageHandling { kAppendToPriorMessage, kPrependToPriorMessage };
+
+ // Make an error with the given code.
+ template <typename ERROR_CODE_TYPE>
+ MakeErrorStream(const char* file, int line, ERROR_CODE_TYPE code)
+ : impl_(new Impl(file, line, code, this, true)) {}
+
+ template <typename T>
+ MakeErrorStreamWithOutput& operator<<(const T& value) {
+ CheckNotDone();
+ impl_->stream_ << value;
+ return impl_->make_error_stream_with_output_wrapper_;
+ }
+
+ // When this message is logged, include the stack trace.
+ MakeErrorStream& with_log_stack_trace() {
+ impl_->should_log_stack_trace_ = true;
+ return *this;
+ }
+
+ // Adds RET_CHECK failure text to error message.
+ MakeErrorStreamWithOutput& add_ret_check_failure(const char* condition) {
+ return *this << "RET_CHECK failure (" << impl_->file_ << ":" << impl_->line_
+ << ") " << condition << " ";
+ }
+
+ private:
+ class Impl {
+ public:
+ Impl(const char* file, int line, tensorflow::error::Code code,
+ MakeErrorStream* error_stream, bool is_logged_by_default = true);
+ Impl(const Status& status, PriorMessageHandling prior_message_handling,
+ const char* file, int line, MakeErrorStream* error_stream);
+
+ ~Impl();
+
+ // This must be called exactly once before destruction.
+ Status GetStatus();
+
+ void CheckNotDone() const;
+
+ private:
+ const char* file_;
+ int line_;
+ tensorflow::error::Code code_;
+
+ PriorMessageHandling prior_message_handling_ = kAppendToPriorMessage;
+ string prior_message_;
+ bool is_done_; // true after Status object has been returned
+ std::ostringstream stream_;
+ bool should_log_;
+ int log_severity_;
+ bool should_log_stack_trace_;
+
+ // Wrapper around the MakeErrorStream object that has a
+ // Status conversion. The first << operator called on
+ // MakeErrorStream will return this object, and only this object
+ // can implicitly convert to Status. The net effect of
+ // this is that you'll get a compile time error if you call
+ // MAKE_ERROR etc. without adding any output.
+ MakeErrorStreamWithOutput make_error_stream_with_output_wrapper_;
+
+ friend class MakeErrorStream;
+ TF_DISALLOW_COPY_AND_ASSIGN(Impl);
+ };
+
+ void CheckNotDone() const;
+
+ // Returns the status. Used by MakeErrorStreamWithOutput.
+ Status GetStatus() const { return impl_->GetStatus(); }
+
+ // Store the actual data on the heap to reduce stack frame sizes.
+ std::unique_ptr<Impl> impl_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MakeErrorStream);
+};
+
+// Provides a conversion to bool so that it can be used inside an if statement
+// that declares a variable.
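+//
+// A sketch of how a status-checking macro might use it (the macro itself is
+// not defined in this file; `expr` is a placeholder Status expression):
+//
+// if (::xla::status_macros::StatusAdaptorForMacros adaptor{(expr)}) {
+// } else {
+// return adaptor.Consume();
+// }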
+class StatusAdaptorForMacros {
+ public:
+ explicit StatusAdaptorForMacros(Status status) : status_(std::move(status)) {}
+
+ StatusAdaptorForMacros(const StatusAdaptorForMacros&) = delete;
+ StatusAdaptorForMacros& operator=(const StatusAdaptorForMacros&) = delete;
+
+ explicit operator bool() const { return TF_PREDICT_TRUE(status_.ok()); }
+
+ Status&& Consume() { return std::move(status_); }
+
+ private:
+ Status status_;
+};
+
+} // namespace status_macros
+} // namespace xla
+
+#define TF_RET_CHECK(condition) \
+ while (TF_PREDICT_FALSE(!(condition))) \
+ return xla::status_macros::MakeErrorStream(__FILE__, __LINE__, \
+ tensorflow::error::INTERNAL) \
+ .with_log_stack_trace() \
+ .add_ret_check_failure(#condition)
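+
+// Example use of TF_RET_CHECK (a sketch; `CheckPositive` and `x` are
+// hypothetical, and the enclosing function must return Status or StatusOr<T>):
+//
+// xla::Status CheckPositive(int x) {
+// TF_RET_CHECK(x > 0) << "got " << x;
+// return xla::Status::OK();
+// }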
+
+#define TF_ASSIGN_OR_ASSERT_OK(lhs, rexpr) \
+ TF_ASSIGN_OR_ASSERT_OK_IMPL( \
+ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr);
+
+#define TF_ASSIGN_OR_ASSERT_OK_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ ASSERT_TRUE(statusor.status().ok()) << statusor.status(); \
+ lhs = statusor.ConsumeValueOrDie()
+
+#define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y)
+#define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
+
+#define TF_ASSIGN_OR_RETURN(...) \
+ TF_STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, TF_ASSIGN_OR_RETURN_IMPL_3, \
+ TF_ASSIGN_OR_RETURN_IMPL_2) \
+ (__VA_ARGS__)
+
+#define TF_STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) NAME
+
+#define TF_ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) \
+ TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr)
+
+#define TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \
+ TF_ASSIGN_OR_RETURN_IMPL( \
+ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr)
+
+#define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (TF_PREDICT_FALSE(!statusor.ok())) { \
+ return statusor.status(); \
+ } \
+ lhs = std::move(statusor.ValueOrDie())
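+
+// Example use of TF_ASSIGN_OR_RETURN (a sketch; `ComputeValue` is a
+// hypothetical function returning xla::StatusOr<int>):
+//
+// xla::StatusOr<int> TwiceComputedValue() {
+// TF_ASSIGN_OR_RETURN(int value, ComputeValue());
+// return 2 * value;
+// }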
+
+#endif // TENSORFLOW_COMPILER_XLA_STATUS_MACROS_H_
diff --git a/tensorflow/compiler/xla/status_macros_test.cc b/tensorflow/compiler/xla/status_macros_test.cc
new file mode 100644
index 0000000000..4e7b9161db
--- /dev/null
+++ b/tensorflow/compiler/xla/status_macros_test.cc
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/status_macros.h"
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+Status RetCheckFail() {
+ TF_RET_CHECK(2 > 3);
+ return Status::OK();
+}
+
+Status RetCheckFailWithExtraMessage() {
+ TF_RET_CHECK(2 > 3) << "extra message";
+ return Status::OK();
+}
+
+Status RetCheckSuccess() {
+ TF_RET_CHECK(3 > 2);
+ return Status::OK();
+}
+
+TEST(StatusMacros, RetCheckFailing) {
+ Status status = RetCheckFail();
+ EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+ EXPECT_MATCH(status.error_message(),
+ xla::testing::ContainsRegex("RET_CHECK failure.*2 > 3"));
+}
+
+TEST(StatusMacros, RetCheckFailingWithExtraMessage) {
+ Status status = RetCheckFailWithExtraMessage();
+ EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+ EXPECT_MATCH(status.error_message(),
+ xla::testing::ContainsRegex("RET_CHECK.*2 > 3 extra message"));
+}
+
+TEST(StatusMacros, RetCheckSucceeding) {
+ Status status = RetCheckSuccess();
+ EXPECT_IS_OK(status);
+}
+
+StatusOr<int> CreateIntSuccessfully() { return 42; }
+
+StatusOr<int> CreateIntUnsuccessfully() {
+ return tensorflow::errors::Internal("foobar");
+}
+
+TEST(StatusMacros, AssignOrAssertOnOK) {
+ TF_ASSIGN_OR_ASSERT_OK(int result, CreateIntSuccessfully());
+ EXPECT_EQ(42, result);
+}
+
+Status ReturnStatusOK() { return Status::OK(); }
+
+Status ReturnStatusError() { return (tensorflow::errors::Internal("foobar")); }
+
+using StatusReturningFunction = std::function<Status()>;
+
+StatusOr<int> CallStatusReturningFunction(StatusReturningFunction func) {
+ TF_RETURN_IF_ERROR(func());
+ return 42;
+}
+
+TEST(StatusMacros, ReturnIfErrorOnOK) {
+ StatusOr<int> rc = CallStatusReturningFunction(ReturnStatusOK);
+ EXPECT_IS_OK(rc);
+ EXPECT_EQ(42, rc.ConsumeValueOrDie());
+}
+
+TEST(StatusMacros, ReturnIfErrorOnError) {
+ StatusOr<int> rc = CallStatusReturningFunction(ReturnStatusError);
+ EXPECT_FALSE(rc.ok());
+ EXPECT_EQ(rc.status().code(), tensorflow::error::INTERNAL);
+}
+
+TEST(StatusMacros, AssignOrReturnSuccessfully) {
+ Status status = []() {
+ TF_ASSIGN_OR_RETURN(int value, CreateIntSuccessfully());
+ EXPECT_EQ(value, 42);
+ return Status::OK();
+ }();
+ EXPECT_IS_OK(status);
+}
+
+TEST(StatusMacros, AssignOrReturnUnsuccessfully) {
+ Status status = []() {
+ TF_ASSIGN_OR_RETURN(int value, CreateIntUnsuccessfully());
+ (void)value;
+ return Status::OK();
+ }();
+ EXPECT_FALSE(status.ok());
+ EXPECT_EQ(status.code(), tensorflow::error::INTERNAL);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/statusor.cc b/tensorflow/compiler/xla/statusor.cc
new file mode 100644
index 0000000000..36f08fc99f
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor.cc
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/statusor.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace internal {
+
+Status StatusOrHelper::HandleInvalidStatusCtorArg() {
+ const char* kMessage =
+ "Status::OK is not a valid constructor argument to StatusOr<T>";
+ LOG(ERROR) << kMessage;
+ // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
+ return Status(tensorflow::error::INTERNAL, kMessage);
+}
+
+Status StatusOrHelper::HandleNullObjectCtorArg() {
+ const char* kMessage =
+ "NULL is not a valid constructor argument to StatusOr<T*>";
+ LOG(ERROR) << kMessage;
+ // In optimized builds, we will fall back to tensorflow::error::INTERNAL.
+ return Status(tensorflow::error::INTERNAL, kMessage);
+}
+
+void StatusOrHelper::Crash(const Status& status) {
+ LOG(FATAL) << "Attempting to fetch value instead of handling error "
+ << status;
+}
+
+} // namespace internal
+} // namespace xla
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
new file mode 100644
index 0000000000..8046a2216f
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor.h
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// StatusOr<T> is the union of a Status object and a T
+// object. StatusOr models the concept of an object that is either a
+// usable value, or an error Status explaining why such a value is
+// not present. To this end, StatusOr<T> does not allow its Status
+// value to be Status::OK. Furthermore, the value of a StatusOr<T*>
+// must not be null. This is enforced by a debug check in most cases,
+// but even when it is not, clients must not set the value to null.
+//
+// The primary use-case for StatusOr<T> is as the return value of a
+// function which may fail.
+//
+// Example client usage for a StatusOr<T>, where T is not a pointer:
+//
+// StatusOr<float> result = DoBigCalculationThatCouldFail();
+// if (result.ok()) {
+// float answer = result.ValueOrDie();
+// printf("Big calculation yielded: %f", answer);
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example client usage for a StatusOr<T*>:
+//
+// StatusOr<Foo*> result = FooFactory::MakeNewFoo(arg);
+// if (result.ok()) {
+// std::unique_ptr<Foo> foo(result.ValueOrDie());
+// foo->DoSomethingCool();
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example client usage for a StatusOr<std::unique_ptr<T>>:
+//
+// StatusOr<std::unique_ptr<Foo>> result = FooFactory::MakeNewFoo(arg);
+// if (result.ok()) {
+// std::unique_ptr<Foo> foo = std::move(result.ValueOrDie());
+// foo->DoSomethingCool();
+// } else {
+// LOG(ERROR) << result.status();
+// }
+//
+// Example factory implementation returning StatusOr<T*>:
+//
+// StatusOr<Foo*> FooFactory::MakeNewFoo(int arg) {
+// if (arg <= 0) {
+// return tensorflow::InvalidArgument("Arg must be positive");
+// } else {
+// return new Foo(arg);
+// }
+// }
+//
+// Note that the assignment operators require that destroying the currently
+// stored value cannot invalidate the argument; in other words, the argument
+// cannot be an alias for the current value, or anything owned by the current
+// value.
+#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_H_
+#define TENSORFLOW_COMPILER_XLA_STATUSOR_H_
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+#if defined(__clang__)
+// Only clang supports warn_unused_result as a type annotation.
+template <typename T, bool CopyConstructible>
+class TF_MUST_USE_RESULT StatusOr;
+#endif
+
+template <typename T,
+ bool CopyConstructible = std::is_copy_constructible<T>::value>
+class StatusOr {
+ template <typename U, bool UC>
+ friend class StatusOr;
+
+ public:
+ typedef T element_type;
+
+ // Construct a new StatusOr with Status::UNKNOWN status
+ StatusOr();
+
+ // Construct a new StatusOr with the given non-ok status. After calling
+ // this constructor, calls to ValueOrDie() will CHECK-fail.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return
+ // value, so it is convenient and sensible to be able to do 'return
+ // Status()' when the return type is StatusOr<T>.
+ //
+ // REQUIRES: status != Status::OK. This requirement is DCHECKed.
+ // In optimized builds, passing Status::OK here will have the effect
+ // of passing tensorflow::error::INTERNAL as a fallback.
+ StatusOr(Status status); // NOLINT
+ StatusOr(tensorflow::Status status); // NOLINT
+
+ // Construct a new StatusOr with the given value. If T is a plain pointer,
+ // value must not be NULL. After calling this constructor, calls to
+ // ValueOrDie() will succeed, and calls to status() will return OK.
+ //
+ // NOTE: Not explicit - we want to use StatusOr<T> as a return type
+ // so it is convenient and sensible to be able to do 'return T()'
+ // when the return type is StatusOr<T>.
+ //
+ // REQUIRES: if T is a plain pointer, value != NULL. This requirement is
+ // DCHECKed. In optimized builds, passing a NULL pointer here will have
+ // the effect of passing tensorflow::error::INTERNAL as a fallback.
+ StatusOr(const T& value); // NOLINT
+
+ // Copy constructor.
+ StatusOr(const StatusOr& other) = default;
+
+ // Conversion copy constructor, T must be copy constructible from U
+ template <typename U>
+ StatusOr(const StatusOr<U>& other);
+
+ // Assignment operator.
+ StatusOr& operator=(const StatusOr& other) = default;
+
+ // Conversion assignment operator, T must be assignable from U
+ template <typename U>
+ StatusOr& operator=(const StatusOr<U>& other);
+
+ // Move constructor and move-assignment operator.
+ StatusOr(StatusOr&& other) = default;
+ StatusOr& operator=(StatusOr&& other) = default;
+
+ // Rvalue-reference overloads of the other constructors and assignment
+ // operators, to support move-only types and avoid unnecessary copying.
+ //
+ // Implementation note: we could avoid all these rvalue-reference overloads
+ // if the existing lvalue-reference overloads took their arguments by value
+ // instead. I think this would also let us omit the conversion assignment
+ // operator altogether, since we'd get the same functionality for free
+ // from the implicit conversion constructor and ordinary assignment.
+ // However, this could result in extra copy operations unless we use
+ // std::move to avoid them, and we can't use std::move because this code
+ // needs to be portable to C++03.
+ StatusOr(T&& value); // NOLINT
+ template <typename U>
+ StatusOr(StatusOr<U>&& other);
+
+ // Returns a reference to our status. If this contains a T, then
+ // returns Status::OK.
+ const Status& status() const { return status_; }
+
+ // Returns this->status().ok()
+ bool ok() const { return status_.ok(); }
+
+ // Returns a reference to our current value, or CHECK-fails if !this->ok().
+ const T& ValueOrDie() const;
+ T& ValueOrDie();
+
+ // Moves our current value out of this object and returns it, or CHECK-fails
+ // if !this->ok().
+ // Use of this method is discouraged; prefer std::move(statusor.ValueOrDie())
+ // instead.
+ T ConsumeValueOrDie() { return std::move(ValueOrDie()); }
+
+ private:
+ Status status_;
+ T value_;
+};
+
+// Partial specialization for when T is not copy-constructible. This uses all
+// methods from the core implementation, but removes copy assignment and copy
+// construction.
+template <typename T>
+class StatusOr<T, false> : public StatusOr<T, true> {
+ public:
+ // Remove copies.
+ StatusOr(const StatusOr& other) = delete;
+ StatusOr& operator=(const StatusOr& other) = delete;
+ template <typename U>
+ StatusOr(const StatusOr<U>& other) = delete;
+ StatusOr(const T& value) = delete;
+
+ // Use the superclass version for other constructors and operators.
+ StatusOr() = default;
+ StatusOr(StatusOr&& other) = default;
+ StatusOr& operator=(StatusOr&& other) = default;
+ StatusOr(T&& value) // NOLINT
+ : StatusOr<T, true>::StatusOr(std::move(value)) {}
+ StatusOr(Status status) // NOLINT
+ : StatusOr<T, true>::StatusOr(std::move(status)) {}
+ StatusOr(tensorflow::Status status) // NOLINT
+ : StatusOr<T, true>::StatusOr(std::move(status)) {}
+ template <typename U>
+ StatusOr(StatusOr<U>&& other) // NOLINT
+ : StatusOr<T, true>::StatusOr(std::move(other)) {}
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// Implementation details for StatusOr<T>
+
+namespace internal {
+
+class StatusOrHelper {
+ public:
+ // Move type-agnostic error handling to the .cc.
+ static Status HandleInvalidStatusCtorArg();
+ static Status HandleNullObjectCtorArg();
+ static void Crash(const Status& status);
+
+ // Customized behavior for StatusOr<T> vs. StatusOr<T*>
+ template <typename T>
+ struct Specialize;
+};
+
+template <typename T>
+struct StatusOrHelper::Specialize {
+ // For non-pointer T, a reference can never be NULL.
+ static inline bool IsValueNull(const T& t) { return false; }
+};
+
+template <typename T>
+struct StatusOrHelper::Specialize<T*> {
+ static inline bool IsValueNull(const T* t) { return t == NULL; }
+};
+
+} // namespace internal
+
+template <typename T, bool CopyConstructible>
+inline StatusOr<T, CopyConstructible>::StatusOr()
+ : status_(tensorflow::error::UNKNOWN, "") {}
+
+template <typename T, bool CopyConstructible>
+inline StatusOr<T, CopyConstructible>::StatusOr(Status status)
+ : status_(std::move(status)) {
+ if (status_.ok()) {
+ status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg();
+ }
+}
+
+template <typename T, bool CopyConstructible>
+inline StatusOr<T, CopyConstructible>::StatusOr(tensorflow::Status status)
+ : status_(status) {
+ if (status_.ok()) {
+ status_ = internal::StatusOrHelper::HandleInvalidStatusCtorArg();
+ }
+}
+
+template <typename T, bool CopyConstructible>
+inline StatusOr<T, CopyConstructible>::StatusOr(const T& value)
+ : value_(value) {
+ if (internal::StatusOrHelper::Specialize<T>::IsValueNull(value)) {
+ status_ = internal::StatusOrHelper::HandleNullObjectCtorArg();
+ }
+}
+
+template <typename T, bool CopyConstructible>
+template <typename U>
+inline StatusOr<T, CopyConstructible>::StatusOr(const StatusOr<U>& other)
+ : status_(other.status_), value_(other.value_) {}
+
+template <typename T, bool CopyConstructible>
+inline StatusOr<T, CopyConstructible>::StatusOr(T&& value)
+ : value_(std::move(value)) {
+ if (internal::StatusOrHelper::Specialize<T>::IsValueNull(value_)) {
+ status_ = internal::StatusOrHelper::HandleNullObjectCtorArg();
+ }
+}
+
+template <typename T, bool CopyConstructible>
+template <typename U>
+inline StatusOr<T, CopyConstructible>::StatusOr(StatusOr<U>&& other)
+ : status_(std::move(other.status_)), value_(std::move(other.value_)) {}
+
+template <typename T, bool CopyConstructible>
+inline const T& StatusOr<T, CopyConstructible>::ValueOrDie() const {
+ if (!ok()) {
+ internal::StatusOrHelper::Crash(status());
+ }
+ return value_;
+}
+
+template <typename T, bool CopyConstructible>
+inline T& StatusOr<T, CopyConstructible>::ValueOrDie() {
+ if (!status_.ok()) {
+ internal::StatusOrHelper::Crash(status());
+ }
+ return value_;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_H_
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
new file mode 100644
index 0000000000..d98eb27933
--- /dev/null
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -0,0 +1,645 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Unit tests for StatusOr
+
+#include "tensorflow/compiler/xla/statusor.h"
+
+#include <memory>
+#include <type_traits>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace xla {
+namespace {
+
+using tensorflow::Status;
+
+class Base1 {
+ public:
+ virtual ~Base1() {}
+ int pad;
+};
+
+class Base2 {
+ public:
+ virtual ~Base2() {}
+ int yetotherpad;
+};
+
+class Derived : public Base1, public Base2 {
+ public:
+ virtual ~Derived() {}
+ int evenmorepad;
+};
+
+class CopyNoAssign {
+ public:
+ explicit CopyNoAssign(int value) : foo(value) {}
+ CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {}
+ int foo;
+
+ private:
+ const CopyNoAssign& operator=(const CopyNoAssign&);
+};
+
+StatusOr<std::unique_ptr<int>> ReturnUniquePtr() {
+ // Uses implicit constructor from T&&
+ return std::unique_ptr<int>(new int(0));
+}
+
+TEST(StatusOr, ElementType) {
+ static_assert(std::is_same<StatusOr<int>::element_type, int>(), "");
+ static_assert(std::is_same<StatusOr<char>::element_type, char>(), "");
+}
+
+TEST(StatusOr, TestMoveOnlyInitialization) {
+ StatusOr<std::unique_ptr<int>> thing(ReturnUniquePtr());
+ ASSERT_TRUE(thing.ok());
+ EXPECT_EQ(0, *thing.ValueOrDie());
+ int* previous = thing.ValueOrDie().get();
+
+ thing = ReturnUniquePtr();
+ EXPECT_TRUE(thing.ok());
+ EXPECT_EQ(0, *thing.ValueOrDie());
+ EXPECT_NE(previous, thing.ValueOrDie().get());
+}
+
+TEST(StatusOr, TestMoveOnlyStatusCtr) {
+ StatusOr<std::unique_ptr<int>> thing(tensorflow::errors::Cancelled(""));
+ ASSERT_FALSE(thing.ok());
+}
+
+TEST(StatusOr, TestMoveOnlyValueExtraction) {
+ StatusOr<std::unique_ptr<int>> thing(ReturnUniquePtr());
+ ASSERT_TRUE(thing.ok());
+ std::unique_ptr<int> ptr = thing.ConsumeValueOrDie();
+ EXPECT_EQ(0, *ptr);
+
+ thing = std::move(ptr);
+ ptr = std::move(thing.ValueOrDie());
+ EXPECT_EQ(0, *ptr);
+}
+
+TEST(StatusOr, TestMoveOnlyConversion) {
+ StatusOr<std::unique_ptr<const int>> const_thing(ReturnUniquePtr());
+ EXPECT_TRUE(const_thing.ok());
+ EXPECT_EQ(0, *const_thing.ValueOrDie());
+
+ // Test rvalue converting assignment
+ const int* const_previous = const_thing.ValueOrDie().get();
+ const_thing = ReturnUniquePtr();
+ EXPECT_TRUE(const_thing.ok());
+ EXPECT_EQ(0, *const_thing.ValueOrDie());
+ EXPECT_NE(const_previous, const_thing.ValueOrDie().get());
+}
+
+TEST(StatusOr, TestMoveOnlyVector) {
+ // Sanity check that StatusOr<MoveOnly> works in vector.
+ std::vector<StatusOr<std::unique_ptr<int>>> vec;
+ vec.push_back(ReturnUniquePtr());
+ vec.resize(2);
+ auto another_vec = std::move(vec);
+ EXPECT_EQ(0, *another_vec[0].ValueOrDie());
+ EXPECT_EQ(tensorflow::error::UNKNOWN, another_vec[1].status().code());
+}
+
+TEST(StatusOr, TestMoveWithValuesAndErrors) {
+ StatusOr<string> status_or(string(1000, '0'));
+ StatusOr<string> value1(string(1000, '1'));
+ StatusOr<string> value2(string(1000, '2'));
+ StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
+ StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
+
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
+
+ // Overwrite the value in status_or with another value.
+ status_or = std::move(value1);
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
+
+ // Overwrite the value in status_or with an error.
+ status_or = std::move(error1);
+ ASSERT_FALSE(status_or.ok());
+ EXPECT_EQ("error1", status_or.status().error_message());
+
+ // Overwrite the error in status_or with another error.
+ status_or = std::move(error2);
+ ASSERT_FALSE(status_or.ok());
+ EXPECT_EQ("error2", status_or.status().error_message());
+
+ // Overwrite the error with a value.
+ status_or = std::move(value2);
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
+}
+
+TEST(StatusOr, TestCopyWithValuesAndErrors) {
+ StatusOr<string> status_or(string(1000, '0'));
+ StatusOr<string> value1(string(1000, '1'));
+ StatusOr<string> value2(string(1000, '2'));
+ StatusOr<string> error1(Status(tensorflow::error::UNKNOWN, "error1"));
+ StatusOr<string> error2(Status(tensorflow::error::UNKNOWN, "error2"));
+
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '0'), status_or.ValueOrDie());
+
+ // Overwrite the value in status_or with another value.
+ status_or = value1;
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '1'), status_or.ValueOrDie());
+
+ // Overwrite the value in status_or with an error.
+ status_or = error1;
+ ASSERT_FALSE(status_or.ok());
+ EXPECT_EQ("error1", status_or.status().error_message());
+
+ // Overwrite the error in status_or with another error.
+ status_or = error2;
+ ASSERT_FALSE(status_or.ok());
+ EXPECT_EQ("error2", status_or.status().error_message());
+
+ // Overwrite the error with a value.
+ status_or = value2;
+ ASSERT_TRUE(status_or.ok());
+ EXPECT_EQ(string(1000, '2'), status_or.ValueOrDie());
+
+ // Verify original values unchanged.
+ EXPECT_EQ(string(1000, '1'), value1.ValueOrDie());
+ EXPECT_EQ("error1", error1.status().error_message());
+ EXPECT_EQ("error2", error2.status().error_message());
+ EXPECT_EQ(string(1000, '2'), value2.ValueOrDie());
+}
+
+TEST(StatusOr, TestDefaultCtor) {
+ StatusOr<int> thing;
+ EXPECT_FALSE(thing.ok());
+ EXPECT_EQ(thing.status().code(), tensorflow::error::UNKNOWN);
+}
+
+TEST(StatusOrDeathTest, TestDefaultCtorValue) {
+ StatusOr<int> thing;
+ EXPECT_DEATH(thing.ValueOrDie(), "");
+
+ const StatusOr<int> thing2;
+ EXPECT_DEATH(thing2.ValueOrDie(), "");
+}
+
+TEST(StatusOr, TestStatusCtor) {
+ StatusOr<int> thing(Status(tensorflow::error::CANCELLED, ""));
+ EXPECT_FALSE(thing.ok());
+ EXPECT_EQ(thing.status().code(), tensorflow::error::CANCELLED);
+}
+
+TEST(StatusOr, TestValueCtor) {
+ const int kI = 4;
+ const StatusOr<int> thing(kI);
+ EXPECT_TRUE(thing.ok());
+ EXPECT_EQ(kI, thing.ValueOrDie());
+}
+
+TEST(StatusOr, TestCopyCtorStatusOk) {
+ const int kI = 4;
+ const StatusOr<int> original(kI);
+ const StatusOr<int> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+ EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie());
+}
+
+TEST(StatusOr, TestCopyCtorStatusNotOk) {
+ StatusOr<int> original(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<int> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+}
+
+TEST(StatusOr, TestCopyCtorNonAssignable) {
+ const int kI = 4;
+ CopyNoAssign value(kI);
+ StatusOr<CopyNoAssign> original(value);
+ StatusOr<CopyNoAssign> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+ EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo);
+}
+
+TEST(StatusOr, TestCopyCtorStatusOKConverting) {
+ const int kI = 4;
+ StatusOr<int> original(kI);
+ StatusOr<double> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+ EXPECT_DOUBLE_EQ(original.ValueOrDie(), copy.ValueOrDie());
+}
+
+TEST(StatusOr, TestCopyCtorStatusNotOkConverting) {
+ StatusOr<int> original(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<double> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+}
+
+TEST(StatusOr, TestAssignmentStatusOk) {
+ const int kI = 4;
+ StatusOr<int> source(kI);
+ StatusOr<int> target;
+ target = source;
+ EXPECT_EQ(target.status(), source.status());
+ EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie());
+}
+
+TEST(StatusOr, TestAssignmentStatusNotOk) {
+ StatusOr<int> source(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<int> target;
+ target = source;
+ EXPECT_EQ(target.status(), source.status());
+}
+
+TEST(StatusOr, TestStatus) {
+ StatusOr<int> good(4);
+ EXPECT_TRUE(good.ok());
+ StatusOr<int> bad(Status(tensorflow::error::CANCELLED, ""));
+ EXPECT_FALSE(bad.ok());
+ EXPECT_EQ(bad.status(), Status(tensorflow::error::CANCELLED, ""));
+}
+
+TEST(StatusOr, TestValue) {
+ const int kI = 4;
+ StatusOr<int> thing(kI);
+ EXPECT_EQ(kI, thing.ValueOrDie());
+}
+
+TEST(StatusOr, TestValueConst) {
+ const int kI = 4;
+ const StatusOr<int> thing(kI);
+ EXPECT_EQ(kI, thing.ValueOrDie());
+}
+
+TEST(StatusOrDeathTest, TestValueNotOk) {
+ StatusOr<int> thing(Status(tensorflow::error::CANCELLED, "cancelled"));
+ EXPECT_DEATH(thing.ValueOrDie(), "cancelled");
+}
+
+TEST(StatusOrDeathTest, TestValueNotOkConst) {
+ const StatusOr<int> thing(Status(tensorflow::error::UNKNOWN, ""));
+ EXPECT_DEATH(thing.ValueOrDie(), "");
+}
+
+TEST(StatusOr, TestPointerDefaultCtor) {
+ StatusOr<int*> thing;
+ EXPECT_FALSE(thing.ok());
+ EXPECT_EQ(thing.status().code(), tensorflow::error::UNKNOWN);
+}
+
+TEST(StatusOrDeathTest, TestPointerDefaultCtorValue) {
+ StatusOr<int*> thing;
+ EXPECT_DEATH(thing.ValueOrDie(), "");
+}
+
+TEST(StatusOr, TestPointerStatusCtor) {
+ StatusOr<int*> thing(Status(tensorflow::error::CANCELLED, ""));
+ EXPECT_FALSE(thing.ok());
+ EXPECT_EQ(thing.status(), Status(tensorflow::error::CANCELLED, ""));
+}
+
+TEST(StatusOr, TestPointerValueCtor) {
+ const int kI = 4;
+ StatusOr<const int*> thing(&kI);
+ EXPECT_TRUE(thing.ok());
+ EXPECT_EQ(&kI, thing.ValueOrDie());
+}
+
+TEST(StatusOr, TestPointerCopyCtorStatusOk) {
+ const int kI = 0;
+ StatusOr<const int*> original(&kI);
+ StatusOr<const int*> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+ EXPECT_EQ(original.ValueOrDie(), copy.ValueOrDie());
+}
+
+TEST(StatusOr, TestPointerCopyCtorStatusNotOk) {
+ StatusOr<int*> original(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<int*> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+}
+
+TEST(StatusOr, TestPointerCopyCtorStatusOKConverting) {
+ Derived derived;
+ StatusOr<Derived*> original(&derived);
+ StatusOr<Base2*> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+ EXPECT_EQ(static_cast<const Base2*>(original.ValueOrDie()),
+ copy.ValueOrDie());
+}
+
+TEST(StatusOr, TestPointerCopyCtorStatusNotOkConverting) {
+ StatusOr<Derived*> original(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<Base2*> copy(original);
+ EXPECT_EQ(copy.status(), original.status());
+}
+
+TEST(StatusOr, TestPointerAssignmentStatusOk) {
+ const int kI = 0;
+ StatusOr<const int*> source(&kI);
+ StatusOr<const int*> target;
+ target = source;
+ EXPECT_EQ(target.status(), source.status());
+ EXPECT_EQ(source.ValueOrDie(), target.ValueOrDie());
+}
+
+TEST(StatusOr, TestPointerAssignmentStatusNotOk) {
+ StatusOr<int*> source(Status(tensorflow::error::CANCELLED, ""));
+ StatusOr<int*> target;
+ target = source;
+ EXPECT_EQ(target.status(), source.status());
+}
+
+TEST(StatusOr, TestPointerStatus) {
+ const int kI = 0;
+ StatusOr<const int*> good(&kI);
+ EXPECT_TRUE(good.ok());
+ StatusOr<const int*> bad(Status(tensorflow::error::CANCELLED, ""));
+ EXPECT_EQ(bad.status(), Status(tensorflow::error::CANCELLED, ""));
+}
+
+TEST(StatusOr, TestPointerValue) {
+ const int kI = 0;
+ StatusOr<const int*> thing(&kI);
+ EXPECT_EQ(&kI, thing.ValueOrDie());
+}
+
+TEST(StatusOr, TestPointerValueConst) {
+ const int kI = 0;
+ const StatusOr<const int*> thing(&kI);
+ EXPECT_EQ(&kI, thing.ValueOrDie());
+}
+
+// NOTE(tucker): tensorflow::StatusOr does not support this kind
+// of resize op.
+// TEST(StatusOr, StatusOrVectorOfUniquePointerCanResize) {
+// using EvilType = std::vector<std::unique_ptr<int>>;
+// static_assert(std::is_copy_constructible<EvilType>::value, "");
+// std::vector<StatusOr<EvilType>> v(5);
+// v.reserve(v.capacity() + 10);
+// }
+
+TEST(StatusOrDeathTest, TestPointerValueNotOk) {
+ StatusOr<int*> thing(Status(tensorflow::error::CANCELLED, "cancelled"));
+ EXPECT_DEATH(thing.ValueOrDie(), "cancelled");
+}
+
+TEST(StatusOrDeathTest, TestPointerValueNotOkConst) {
+ const StatusOr<int*> thing(Status(tensorflow::error::CANCELLED, "cancelled"));
+ EXPECT_DEATH(thing.ValueOrDie(), "cancelled");
+}
+
+static StatusOr<int> MakeStatus() { return 100; }
+// A factory to help us benchmark the various factory styles. All of
+// the factory methods are marked as non-inlineable so as to more
+// accurately simulate calling a factory for which you do not have
+// visibility of implementation. Similarly, the value_ variable is
+// marked volatile to prevent the compiler from getting too clever
+// about detecting that the same value is used in all loop iterations.
+template <typename T>
+class BenchmarkFactory {
+ public:
+ // Construct a new factory. Allocate an object which will always
+ // be the result of the factory methods.
+ BenchmarkFactory() : value_(new T) {}
+
+ // Destroy this factory, including the result value.
+ ~BenchmarkFactory() { delete value_; }
+
+ // A trivial factory that just returns the value. There is no status
+ // object that could be returned to encapsulate an error.
+ T* TrivialFactory() TF_ATTRIBUTE_NOINLINE { return value_; }
+
+ // A more sophisticated factory, which returns a status to indicate
+ // the result of the operation. The factory result is populated into
+ // the user-provided pointer 'result'.
+ Status ArgumentFactory(T** result) TF_ATTRIBUTE_NOINLINE {
+ *result = value_;
+ return Status::OK();
+ }
+
+ Status ArgumentFactoryFail(T** result) TF_ATTRIBUTE_NOINLINE {
+ *result = NULL;
+ return Status(tensorflow::error::CANCELLED, "");
+ }
+
+ Status ArgumentFactoryFailShortMsg(T** result) TF_ATTRIBUTE_NOINLINE {
+ *result = NULL;
+ return Status(::tensorflow::error::INTERNAL, "");
+ }
+
+ Status ArgumentFactoryFailLongMsg(T** result) TF_ATTRIBUTE_NOINLINE {
+ *result = NULL;
+ return Status(::tensorflow::error::INTERNAL,
+ "a big string of message junk that will never be read");
+ }
+
+ // A factory that returns a StatusOr<T*>. If the factory operation
+ // is OK, then the StatusOr<T*> will hold a T*. Otherwise, it will
+ // hold a status explaining the error.
+ StatusOr<T*> StatusOrFactory() TF_ATTRIBUTE_NOINLINE {
+ return static_cast<T*>(value_);
+ }
+
+ StatusOr<T*> StatusOrFactoryFail() TF_ATTRIBUTE_NOINLINE {
+ return Status(tensorflow::error::CANCELLED, "");
+ }
+
+ StatusOr<T*> StatusOrFactoryFailShortMsg() TF_ATTRIBUTE_NOINLINE {
+ return Status(::tensorflow::error::INTERNAL, "");
+ }
+
+ StatusOr<T*> StatusOrFactoryFailLongMsg() TF_ATTRIBUTE_NOINLINE {
+ return Status(::tensorflow::error::INTERNAL,
+ "a big string of message junk that will never be read");
+ }
+
+ private:
+ T* volatile value_;
+ TF_DISALLOW_COPY_AND_ASSIGN(BenchmarkFactory);
+};
+
+// A simple type we use with the factory.
+class BenchmarkType {
+ public:
+ BenchmarkType() {}
+ virtual ~BenchmarkType() {}
+ virtual void DoWork() TF_ATTRIBUTE_NOINLINE {}
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(BenchmarkType);
+};
+
+// Calibrate the amount of time spent just calling DoWork. Since each of our
+// benchmarks does this, we can subtract it out of the benchmark results.
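+// (Illustrative arithmetic with hypothetical numbers: if BM_StatusOrFactory
+// reported 7 ns/iter and BM_CalibrateWorkLoop reported 2 ns/iter, the
+// StatusOr-related overhead would be roughly 7 - 2 = 5 ns/iter.)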
+static void BM_CalibrateWorkLoop(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ BenchmarkType* result = factory.TrivialFactory();
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ if (result != NULL) result->DoWork();
+ }
+}
+BENCHMARK(BM_CalibrateWorkLoop);
+
+// Measure the time taken to call into the factory, return the value,
+// determine that it is OK, and invoke a trivial function.
+static void BM_TrivialFactory(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ BenchmarkType* result = factory.TrivialFactory();
+ if (result != NULL) result->DoWork();
+ }
+}
+BENCHMARK(BM_TrivialFactory);
+
+// Measure the time taken to call into the factory, providing an
+// out-param for the result, evaluating the status result and the
+// result pointer, and invoking the trivial function.
+static void BM_ArgumentFactory(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ BenchmarkType* result = NULL;
+ Status status = factory.ArgumentFactory(&result);
+ if (status.ok() && result != NULL) {
+ result->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_ArgumentFactory);
+
+// Measure the time to use the StatusOr<T*> factory, evaluate the result,
+// and invoke the trivial function.
+static void BM_StatusOrFactory(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
+ if (result.ok()) {
+ result.ValueOrDie()->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_StatusOrFactory);
+
+// Measure the time taken to call into the factory, providing an
+// out-param for the result, evaluating the status result and the
+// result pointer, and invoking the trivial function.
+static void BM_ArgumentFactoryFail(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ BenchmarkType* result = NULL;
+ Status status = factory.ArgumentFactoryFail(&result);
+ if (status.ok() && result != NULL) {
+ result->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_ArgumentFactoryFail);
+
+// Measure the time to use the StatusOr<T*> factory, evaluate the result,
+// and invoke the trivial function.
+static void BM_StatusOrFactoryFail(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
+ if (result.ok()) {
+ result.ValueOrDie()->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_StatusOrFactoryFail);
+
+// Measure the time taken to call into the factory, providing an
+// out-param for the result, evaluating the status result and the
+// result pointer, and invoking the trivial function.
+static void BM_ArgumentFactoryFailShortMsg(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ BenchmarkType* result = NULL;
+ Status status = factory.ArgumentFactoryFailShortMsg(&result);
+ if (status.ok() && result != NULL) {
+ result->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_ArgumentFactoryFailShortMsg);
+
+// Measure the time to use the StatusOr<T*> factory, evaluate the result,
+// and invoke the trivial function.
+static void BM_StatusOrFactoryFailShortMsg(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
+ if (result.ok()) {
+ result.ValueOrDie()->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_StatusOrFactoryFailShortMsg);
+
+// Measure the time taken to call into the factory, providing an
+// out-param for the result, evaluating the status result and the
+// result pointer, and invoking the trivial function.
+static void BM_ArgumentFactoryFailLongMsg(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ BenchmarkType* result = NULL;
+ Status status = factory.ArgumentFactoryFailLongMsg(&result);
+ if (status.ok() && result != NULL) {
+ result->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_ArgumentFactoryFailLongMsg);
+
+// Measure the time to use the StatusOr<T*> factory, evaluate the result,
+// and invoke the trivial function.
+static void BM_StatusOrFactoryFailLongMsg(int iters) {
+ tensorflow::testing::StopTiming();
+ BenchmarkFactory<BenchmarkType> factory;
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i != iters; ++i) {
+ StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
+ if (result.ok()) {
+ result.ValueOrDie()->DoWork();
+ }
+ }
+}
+BENCHMARK(BM_StatusOrFactoryFailLongMsg);
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/test_helpers.cc b/tensorflow/compiler/xla/test_helpers.cc
new file mode 100644
index 0000000000..02abfdeab8
--- /dev/null
+++ b/tensorflow/compiler/xla/test_helpers.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace xla {
+namespace testing {
+
+AssertionResult::AssertionResult(const AssertionResult& other)
+ : success_(other.success_),
+ message_(other.message_ != nullptr ? new std::string(*other.message_)
+ : static_cast<std::string*>(nullptr)) {
+}
+
+// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE.
+AssertionResult AssertionResult::operator!() const {
+ AssertionResult negation(!success_);
+ if (message_ != nullptr) negation << *message_;
+ return negation;
+}
+
+AssertionResult& AssertionResult::operator=(const AssertionResult& ar) {
+ success_ = ar.success_;
+ message_.reset(ar.message_ != nullptr ? new std::string(*ar.message_)
+ : nullptr);
+ return *this;
+}
+
+AssertionResult AssertionFailure() { return AssertionResult(false); }
+
+AssertionResult AssertionSuccess() { return AssertionResult(true); }
+
+std::function<bool(tensorflow::StringPiece)> ContainsRegex(
+ const tensorflow::StringPiece regex) {
+ return [regex](const tensorflow::StringPiece to_test) {
+ if (RE2::PartialMatch(
+ tensorflow::RegexpStringPiece(to_test.data(), to_test.size()),
+ tensorflow::RegexpStringPiece(regex.data(), regex.size()))) {
+ return true;
+ } else {
+ LOG(ERROR) << "Expected to find " << regex << " in " << to_test;
+ return false;
+ }
+ };
+}
+
+std::function<bool(tensorflow::StringPiece)> HasSubstr(
+ const tensorflow::StringPiece part) {
+ return [part](const tensorflow::StringPiece whole) {
+ return whole.contains(part);
+ };
+}
+
+} // namespace testing
+} // namespace xla
diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h
new file mode 100644
index 0000000000..f923d9f36c
--- /dev/null
+++ b/tensorflow/compiler/xla/test_helpers.h
@@ -0,0 +1,355 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_
+#define TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_
+
+#include <functional>
+#include <list>
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/test.h"
+
+// This module contains a minimal subset of gmock functionality just
+// sufficient to execute the currently existing tests.
+namespace util {
+class Status;
+} // namespace util
+
+namespace xla {
+template <typename T>
+class Array2D;
+class Literal;
+
+namespace testing {
+
+class AssertionResult {
+ public:
+ explicit AssertionResult(bool success) : success_(success) {}
+
+ // Returns true iff the assertion succeeded.
+ operator bool() const { return success_; } // NOLINT
+
+ // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE.
+ AssertionResult operator!() const;
+
+ // Returns the text streamed into this AssertionResult. Test assertions
+ // use it when they fail (i.e., the predicate's outcome doesn't match the
+ // assertion's expectation). When nothing has been streamed into the
+ // object, returns an empty string.
+ const char* message() const {
+ return message_ != nullptr ? message_->c_str() : "";
+ }
+
+ // Streams a custom failure message into this object.
+ template <typename T>
+ AssertionResult& operator<<(const T& value) {
+ AppendMessage(::testing::Message() << value);
+ return *this;
+ }
+
+ // Allows streaming basic output manipulators such as endl or flush into
+ // this object.
+ AssertionResult& operator<<(
+ std::ostream& (*basic_manipulator)(std::ostream& stream)) {
+ AppendMessage(::testing::Message() << basic_manipulator);
+ return *this;
+ }
+
+ // Copy constructor.
+ AssertionResult(const AssertionResult& ar);
+
+ // Assignment operator.
+ AssertionResult& operator=(const AssertionResult&);
+
+ private:
+ // Appends the contents of message to message_.
+ void AppendMessage(const ::testing::Message& a_message) {
+ if (message_ == nullptr) message_.reset(new std::string);
+ message_->append(a_message.GetString().c_str());
+ }
+
+ bool success_ = false;
+
+ // Stores the message describing the condition in case the
+ // expectation construct is not satisfied with the predicate's
+ // outcome. Referenced via a pointer to avoid taking too much stack
+ // frame space with test assertions.
+ std::unique_ptr<std::string> message_;
+};
+
+AssertionResult AssertionFailure();
+
+AssertionResult AssertionSuccess();
+
+std::function<bool(tensorflow::StringPiece)> ContainsRegex(
+ const tensorflow::StringPiece regex);
+
+std::function<bool(tensorflow::StringPiece)> HasSubstr(
+ const tensorflow::StringPiece part);
+
+// Matcher for a vector of same-type values for which equality comparison
+// (operator==/operator!=) is defined.
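+//
+// A minimal usage sketch (illustrative only):
+//   auto matcher = VectorMatcher<int>({1, 2, 3});
+//   EXPECT_TRUE(matcher(std::vector<int>({1, 2, 3})));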
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> VectorMatcher(
+ const std::vector<T>& expected) {
+ return [expected](const std::vector<T>& actual) -> AssertionResult {
+ int len = expected.size();
+ if (actual.size() != len) {
+ return AssertionFailure() << "Actual values len of " << actual.size()
+ << " != expected.size " << len;
+ }
+ for (int i = 0; i < len; ++i) {
+ if (actual[i] != expected[i]) {
+ return AssertionFailure() << "Element " << i << " actual " << actual[i]
+ << " != " << expected[i];
+ }
+ }
+ return AssertionSuccess();
+ };
+}
+
+// Approximate matcher for a vector of floats or similar.
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)>
+ApproxVectorMatcher(const std::vector<T>& expected, float abs_diff,
+ float rel_diff) {
+ return [abs_diff, rel_diff,
+ expected](const std::vector<T>& actual) -> AssertionResult {
+ int len = expected.size();
+ if (actual.size() != len) {
+ AssertionResult ar = AssertionFailure() << "Actual values len of "
+ << actual.size()
+ << " != expected.size " << len;
+ LOG(ERROR) << ar.message();
+ return ar;
+ }
+ for (int i = 0; i < len; ++i) {
+ T diff = actual[i] - expected[i];
+ if (diff < 0) {
+ diff *= -1;
+ }
+ if (diff > abs_diff) {
+ T rdiff = (expected[i] != 0 ? diff / expected[i] : 0.0 * expected[i]);
+ if (rdiff > rel_diff) {
+ AssertionResult ar = AssertionFailure()
+ << "Element " << i << " actual " << actual[i]
+ << " != " << expected[i]
+ << "( abs_diff = " << diff
+ << ", rel_diff = " << rdiff << ")";
+ LOG(ERROR) << ar.message();
+ return ar;
+ }
+ }
+ }
+ return AssertionSuccess();
+ };
+}
+
+// Matches a vector of same-type values against another, succeeding so
+// long as they have the same length and every value in 'actual'
+// matches one in 'expected.' Does not verify an exhaustive
+// one-to-one mapping between the two.
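+// For example (illustrative): with expected = {1, 2, 2}, an actual vector of
+// {1, 1, 2} is accepted, since every actual element appears somewhere in
+// expected, even though the multiplicities differ.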
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)>
+UnorderedElementsAre(const std::vector<T>& expected) {
+ return [expected](const std::vector<T>& actual) -> AssertionResult {
+ if (actual.size() != expected.size()) {
+ return AssertionFailure() << "sizes don't match";
+ }
+ for (auto a : actual) {
+ bool found = false;
+ for (auto e : expected) {
+ if (a == e) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ return AssertionFailure() << "actual element " << a
+ << " not in expected";
+ }
+ }
+ return AssertionSuccess();
+ };
+}
+
+// Overloaded cover functions for UnorderedElementsAre, for the numbers
+// of values used in practice.
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a, T b) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a, T b, T c) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a, T b, T c, T d) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ expected.push_back(d);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a, T b, T c, T d, T e) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ expected.push_back(d);
+ expected.push_back(e);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> UnorderedMatcher(
+ T a, T b, T c, T d, T e, T f) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ expected.push_back(d);
+ expected.push_back(e);
+ expected.push_back(f);
+ return testing::UnorderedElementsAre<T>(expected);
+}
+
+// Overloaded cover functions for VectorMatcher for the numbers of
+// elements used in practice.
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher(
+ T a) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ return testing::VectorMatcher<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher(
+ T a, T b) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ return testing::VectorMatcher<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher(
+ T a, T b, T c) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ return testing::VectorMatcher<T>(expected);
+}
+
+template <typename T>
+std::function<AssertionResult(const std::vector<T>& actual)> OrderedMatcher(
+ T a, T b, T c, T d) {
+ std::vector<T> expected;
+ expected.push_back(a);
+ expected.push_back(b);
+ expected.push_back(c);
+ expected.push_back(d);
+ return testing::VectorMatcher<T>(expected);
+}
+
+// Convert a RepeatedField to a flat vector.
+template <typename T>
+std::vector<T> PBToVec(const tensorflow::protobuf::RepeatedField<T> rf) {
+ return std::vector<T>(rf.begin(), rf.end());
+}
+
+// Convert a List to a flat vector.
+template <typename T>
+std::vector<T> ListToVec(const std::list<T>& l) {
+ return std::vector<T>(l.begin(), l.end());
+}
+
+// Convert a Set to a flat vector.
+template <typename T>
+std::vector<T> SetToVec(const std::set<T>& c) {
+ return std::vector<T>(c.begin(), c.end());
+}
+
+// Convert an Array to a flat vector.
+template <typename T>
+std::vector<T> Array2DToVec(const Array2D<T>& a) {
+ return std::vector<T>(a.data(), a.data() + a.num_elements());
+}
+
+namespace internal_status {
+inline const ::tensorflow::Status& GetStatus(
+ const ::tensorflow::Status& status) {
+ return status;
+}
+
+template <typename T>
+inline const ::tensorflow::Status& GetStatus(const StatusOr<T>& status) {
+ return status.status();
+}
+} // namespace internal_status
+
+} // namespace testing
+} // namespace xla
+
+// The following macros are similar to macros in gmock, but deliberately named
+// differently in order to avoid conflicts in files which include both.
+
+// Macros for testing the results of functions that return tensorflow::Status or
+// StatusOr<T> (for any type T).
+#define EXPECT_IS_OK(expression) \
+ EXPECT_EQ(tensorflow::Status::OK(), \
+ xla::testing::internal_status::GetStatus(expression))
+#undef ASSERT_IS_OK
+#define ASSERT_IS_OK(expression) \
+ ASSERT_EQ(tensorflow::Status::OK(), \
+ xla::testing::internal_status::GetStatus(expression))
+
+// Macros that apply a Matcher to a Value, returning an
+// AssertionResult which gets digested by a standard gunit macro.
+#define EXPECT_MATCH(V, M) EXPECT_TRUE((M)((V)))
+#define ASSERT_MATCH(V, M) ASSERT_TRUE((M)((V)))
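+
+// Illustrative usage sketch (ComputeSomething is a hypothetical function
+// returning StatusOr<int>; it is not part of this library):
+//
+//   EXPECT_IS_OK(ComputeSomething());
+//   std::vector<int> actual = {3, 1, 2};
+//   EXPECT_MATCH(actual, xla::testing::UnorderedMatcher<int>(1, 2, 3));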
+
+#endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
new file mode 100644
index 0000000000..93fe1fee4a
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -0,0 +1,1436 @@
+# Description:
+# Base testing infrastructure for XLA.
+
+licenses(["notice"]) # Apache 2.0
+
+package(
+ default_visibility = [":friends"],
+ features = ["no_layering_check"],
+)
+
+package_group(
+ name = "friends",
+ includes = [
+ "//tensorflow/compiler/xla:friends",
+ ],
+)
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
+load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
+
+# Generate test_suites for all backends, named "${backend}_tests".
+generate_backend_suites()
+
+cc_library(
+ name = "test_macros_header",
+ testonly = True,
+ hdrs = ["test_macros.h"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Generate a test_macros_${BACKEND} library per backend with the proper copts.
+generate_backend_test_macros()
+
+cc_library(
+ name = "test_utils",
+ testonly = True,
+ hdrs = ["test_utils.h"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "literal_test_util",
+ testonly = True,
+ srcs = ["literal_test_util.cc"],
+ hdrs = ["literal_test_util.h"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "hlo_test_base",
+ testonly = True,
+ srcs = ["hlo_test_base.cc"],
+ hdrs = ["hlo_test_base.h"],
+ deps = [
+ ":literal_test_util",
+ "//tensorflow/compiler/xla:shape_layout",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_test_base_flags",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:computation_layout",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_execution_profile",
+ "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ "//third_party/eigen3",
+ ],
+)
+
+cc_binary(
+ name = "local_client_aot_test_helper",
+ srcs = ["local_client_aot_test_helper.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+genrule(
+ name = "local_client_aot_test_computation",
+ outs = ["local_client_aot_test_computation.o"],
+ cmd = "$(location :local_client_aot_test_helper) $(TARGET_CPU) > $(OUTS)",
+ local = 1,
+ tools = [":local_client_aot_test_helper"],
+)
+
+cc_library(
+ name = "client_library_test_base",
+ testonly = True,
+ srcs = ["client_library_test_base.cc"],
+ hdrs = ["client_library_test_base.h"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:hlo_pass_pipeline_flags",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "codegen_test_base",
+ testonly = True,
+ srcs = ["codegen_test_base.cc"],
+ hdrs = ["codegen_test_base.h"],
+ data = [
+ "@llvm//:FileCheck",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/service:backend",
+ "//tensorflow/compiler/xla/service:compiler",
+ "//tensorflow/compiler/xla/service:executable",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_module_config",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_library(
+ name = "local_client_test_base",
+ testonly = True,
+ srcs = ["local_client_test_base.cc"],
+ hdrs = ["local_client_test_base.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ ],
+)
+
+xla_test(
+ name = "bad_rng_shape_validation_test",
+ srcs = ["bad_rng_shape_validation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "check_execution_arity_test",
+ srcs = ["check_execution_arity_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "query_inferred_shape_test",
+ srcs = ["query_inferred_shape_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "while_test",
+ srcs = ["while_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "axpy_simple_test",
+ srcs = ["axpy_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "map_test",
+ srcs = ["map_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "params_test",
+ srcs = ["params_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "pred_test",
+ srcs = ["pred_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "select_test",
+ srcs = ["select_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "unary_op_test",
+ srcs = ["unary_op_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "scalar_computations_test",
+ srcs = ["scalar_computations_test.cc"],
+ shard_count = 16,
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "deallocation_test",
+ srcs = ["deallocation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "deconstruct_tuple_test",
+ srcs = ["deconstruct_tuple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "array_elementwise_ops_test",
+ srcs = ["array_elementwise_ops_test.cc"],
+ # This test includes comparisons to NAN, so disable fast-math.
+ backend_args = {
+ "cpu": ["--xla_fast_math=false"],
+ "cpu_parallel": ["--xla_fast_math=false"],
+ "gpu": ["--xla_fast_math=false"],
+ },
+ shard_count = 25,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:llvm_backend_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dot_operation_test",
+ srcs = ["dot_operation_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Tests the dot operation in some cases that can be performed via a
+# runtime call on some backends - e.g. a runtime call to Eigen.
+xla_test(
+ name = "dot_operation_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": ["--xla_cpu_use_eigen"],
+ "cpu_parallel": ["--xla_cpu_use_eigen"],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+# Repeat dot_operation_runtime_test with single-threaded Eigen.
+xla_test(
+ name = "dot_operation_single_threaded_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": [
+ "--xla_cpu_use_eigen",
+ "--xla_cpu_multi_thread_eigen=false",
+ ],
+ "cpu_parallel": [
+ "--xla_cpu_use_eigen",
+ "--xla_cpu_multi_thread_eigen=false",
+ ],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dot_operation_rowmajor_runtime_test",
+ srcs = ["dot_operation_test.cc"],
+ backend_args = {
+ "cpu": [
+ "--xla_cpu_use_eigen",
+ "--xla_default_layout=major2minor",
+ ],
+ "cpu_parallel": [
+ "--xla_cpu_use_eigen",
+ "--xla_default_layout=major2minor",
+ ],
+ },
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags",
+ "//tensorflow/compiler/xla/legacy_flags:layout_util_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "transpose_test",
+ srcs = ["transpose_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "constants_test",
+ srcs = ["constants_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_test",
+ timeout = "long",
+ srcs = ["convolution_test.cc"],
+ shard_count = 25,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_variants_test",
+ timeout = "long",
+ srcs = ["convolution_variants_test.cc"],
+ backend_tags = {
+ # TODO(b/31436974): Fix msan failure. Failed on 2016-09-12.
+ "cpu": ["nomsan"],
+ "cpu_parallel": ["nomsan"],
+ },
+ shard_count = 30,
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convolution_dimension_numbers_test",
+ timeout = "long",
+ srcs = ["convolution_dimension_numbers_test.cc"],
+ shard_count = 20,
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "batch_normalization_test",
+ srcs = ["batch_normalization_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "slice_test",
+ srcs = ["slice_test.cc"],
+ shard_count = 40,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "multidimensional_slice_test",
+ srcs = ["multidimensional_slice_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "dynamic_ops_test",
+ srcs = ["dynamic_ops_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:device_memory_allocator",
+ "//tensorflow/compiler/xla/service:local_service",
+ "//tensorflow/compiler/xla/service:platform_util",
+ "//tensorflow/compiler/xla/service:shaped_buffer",
+ "//tensorflow/compiler/xla/service:transfer_manager",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "tuple_test",
+ srcs = ["tuple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "vector_ops_reduce_test",
+ srcs = ["vector_ops_reduce_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reduce_test",
+ srcs = ["reduce_test.cc"],
+ shard_count = 40,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reduce_window_test",
+ timeout = "long",
+ srcs = ["reduce_window_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "select_and_scatter_test",
+ timeout = "long",
+ srcs = ["select_and_scatter_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "copy_test",
+ srcs = ["copy_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "call_test",
+ srcs = ["call_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "custom_call_test",
+ srcs = ["custom_call_test.cc"],
+ linkopts = export_dynamic_linkopts,
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "binop_scaling_test",
+ srcs = ["binop_scaling_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "broadcast_simple_test",
+ srcs = ["broadcast_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "pad_test",
+ srcs = ["pad_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "fmax_test",
+ srcs = ["fmax_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "log_test",
+ srcs = ["log_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "matrix_ops_simple_test",
+ srcs = ["matrix_ops_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "prng_test",
+ srcs = ["prng_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reshape_test",
+ srcs = ["reshape_test.cc"],
+ shard_count = 30,
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reverse_test",
+ srcs = ["reverse_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "vector_ops_simple_test",
+ srcs = ["vector_ops_simple_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "concat_test",
+ srcs = ["concat_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "convert_test",
+ srcs = ["convert_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:stream_executor_no_cuda",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "compilation_cache_test",
+ srcs = ["compilation_cache_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla:xla_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "floor_ceil_test",
+ srcs = ["floor_ceil_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "compute_constant_test",
+ srcs = ["compute_constant_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "client_test",
+ srcs = ["client_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:test_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "inprocess_service_test",
+ srcs = ["inprocess_service_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "replay_test",
+ srcs = ["replay_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:protobuf_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "broadcast_test",
+ srcs = ["broadcast_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "round_trip_packed_literal_test",
+ srcs = ["round_trip_packed_literal_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:packed_literal_reader",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "fusion_test",
+ srcs = ["fusion_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_test(
+ name = "local_client_aot_test",
+ srcs = [
+ "local_client_aot_test.cc",
+ ":local_client_aot_test_computation.o",
+ ],
+ deps = [
+ "//tensorflow/compiler/xla:executable_run_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+xla_test(
+ name = "round_trip_transfer_test",
+ srcs = ["round_trip_transfer_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "set_return_value_test",
+ srcs = ["set_return_value_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+xla_test(
+ name = "reshape_motion_test",
+ srcs = ["reshape_motion_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array4d",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:reference_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+cc_test(
+ name = "literal_test_util_test",
+ srcs = ["literal_test_util_test.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
new file mode 100644
index 0000000000..cf6f9a825c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -0,0 +1,1662 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ArrayElementwiseOpTest : public ClientLibraryTestBase {
+ public:
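+ // Tolerance used for all floating-point comparisons in these tests.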
+ ErrorSpec error_spec_{0.0001};
+};
+
+class ArrayElementwiseOpTestParamCount
+ : public ArrayElementwiseOpTest,
+ public ::testing::WithParamInterface<int> {};
+
+XLA_TEST_F(ArrayElementwiseOpTest, NegConstantZeroElementF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto result = builder.Neg(a);
+
+ ComputeAndCompareR1<float>(&builder, {2.5f, -3.14f, -2.25f, 10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, NegConstantS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-1, 0, 1, 324,
+ std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max()});
+ auto result = builder.Neg(a);
+
+ // -min == min for int32 because the negation overflows. In C++ this
+ // calculation is undefined behavior; XLA has not specified overflow
+ // semantics here, so the wrapped result is expected to work.
+ ComputeAndCompareR1<int32>(&builder,
+ {1, 0, -1, -324, std::numeric_limits<int32>::min(),
+ -std::numeric_limits<int32>::max()},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, AddTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {97.5f, 6.27f, 5.0f, 0.5f, -993.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) {
+ const int count = GetParam();
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> a_values;
+ std::vector<float> b_values;
+ for (int i = 0; i < count; ++i) {
+ a_values.push_back(i / static_cast<float>(count));
+ b_values.push_back(2 * i / static_cast<float>(count + 2));
+ }
+
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR1<float>({a_values});
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ auto a_constant = builder.ConstantR1<float>(a_values);
+ auto a_param = builder.Parameter(0, a_literal->shape(), "a_param");
+
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR1<float>({b_values});
+ std::unique_ptr<GlobalData> b_data =
+ client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ auto b_constant = builder.ConstantR1<float>(b_values);
+ auto b_param = builder.Parameter(1, b_literal->shape(), "b_param");
+
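+ // The four sums cover every combination of constant and parameter operands.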
+ auto sum1 = builder.Add(a_constant, b_constant);
+ auto sum2 = builder.Add(a_constant, b_param);
+ auto sum3 = builder.Add(a_param, b_constant);
+ auto sum4 = builder.Add(a_param, b_param);
+
+ auto sum = builder.Add(sum1, sum2);
+ sum = builder.Add(sum, sum3);
+ sum = builder.Add(sum, sum4);
+
+ std::vector<float> expected;
+ for (int64 i = 0; i < count; ++i) {
+ expected.push_back(4 * (a_values[i] + b_values[i]));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {a_data.get(), b_data.get()},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({100.0f, 3.13f, 2.75f, 10.5f, -999.0f});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-102.5f, 0.01f, -0.5f, -20.5f, 1005.0f},
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SubTwoConstantS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-1, 0, 2, 1000000000});
+ auto b = builder.ConstantR1<int32>({-1, 2, 1, -1});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {0, -2, 1, 1000000001}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+ auto add = builder.Sub(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, DivTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({10.0f, 5.1f, 1.0f, 10.0f, -6.0f});
+ auto add = builder.Div(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-0.25f, 5.0f, 2.25f, -1.0f, -1.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, DivTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Div(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>(
+ {-2.5f, 25.5f, 2.25f, -10.0f, 6.0f, 3.0f, 3.0f, -1.0f, -8.0f});
+ auto b = builder.ConstantR1<float>(
+ {10.0f, 5.1f, 1.0f, 10.0f, -6.0f, 2.0f, -2.0f, 7.0f, -4.0f});
+ auto add = builder.Rem(a, b);
+
+ ComputeAndCompareR1<float>(
+ &builder, {-2.5f, 0.0f, 0.25f, 0.0f, -0.0f, 1.0f, 1.0f, -1.0f, -0.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Rem(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemF64s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<double>(
+ {-2.5, 25.5, 2.25, -10.0, 6.0, 3.0, 3.0, -1.0, -8.0});
+ auto b = builder.ConstantR1<double>(
+ {10.0, 5.1, 1.0, 10.0, -6.0, 2.0, -2.0, 7.0, -4.0});
+ auto add = builder.Rem(a, b);
+
+ ComputeAndCompareR1<double>(
+ &builder, {-2.5, 0.0, 0.25, 0.0, -0.0, 1.0, 1.0, -1.0, -0.0}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto b = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {-25.0f, 127.5f, 2.25f, -100.0f, -36.0f},
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantS32s) {
+ std::vector<int32> data = {0,
+ 1,
+ -1,
+ 1234,
+ 0x1a243514,
+ std::numeric_limits<int32>::max(),
+ std::numeric_limits<int32>::min()};
+ // Form the test data set using all products of 'data' with itself.
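+ // Multiply as uint32 so that overflow wraps instead of being undefined
+ // behavior; the wrapped bit pattern matches the expected int32 product.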
+ std::vector<int32> a_data, b_data, expected;
+ for (int32 a : data) {
+ for (int32 b : data) {
+ a_data.push_back(a);
+ b_data.push_back(b);
+ expected.push_back(static_cast<uint32>(a) * static_cast<uint32>(b));
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>(a_data);
+ auto b = builder.ConstantR1<int32>(b_data);
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ auto b = builder.ConstantR1<int32>({});
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) {
+ std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
+ 0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+ // Form the test data set using all products of 'data' with itself.
+ std::vector<uint32> a_data, b_data, expected;
+ for (uint32 a : data) {
+ for (uint32 b : data) {
+ a_data.push_back(a);
+ b_data.push_back(b);
+ expected.push_back(a * b);
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint32>(a_data);
+ auto b = builder.ConstantR1<uint32>(b_data);
+ auto add = builder.Mul(a, b);
+
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalAnd) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, false, true, true});
+ auto b = builder.ConstantR1<bool>({false, true, false, true});
+ auto out = builder.LogicalAnd(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {false, false, false, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto b = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalAnd(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalOr) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, false, true, true});
+ auto b = builder.ConstantR1<bool>({false, true, false, true});
+ auto out = builder.LogicalOr(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto b = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalOr(a, b);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, LogicalNot) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({false, true, true, false});
+ auto out = builder.LogicalNot(a);
+
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({});
+ auto out = builder.LogicalNot(a);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 2.25f, 10.0f, NAN});
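+ // Comparisons involving NaN evaluate to false, so the NaN entries yield
+ // false here.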
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, false, true, false, false}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {false, true, true, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 5.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {true, true, false, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, NAN});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, false, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, false, true, false, false, false, true},
+ {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareEqZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({});
+ auto rhs = builder.ConstantR1<int32>({});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(&builder, {}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Ne(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, false, false, true, false, false, true, true, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<int32>({min, min, min, 0, 0, 0, max, max, max});
+ auto rhs = builder.ConstantR1<int32>({min, 0, max, -1, 0, 1, min, 0, max});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, false, false, true, false, false, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareEqU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Eq(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, false, true, false, false, false, true},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareNeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Ne(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, true, false, true, true, true, false}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Ge(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, false, false, true, true, false, true, true, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareGtU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Gt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, false, false, true, false, false, true, true, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLeU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Le(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {true, true, true, false, true, true, false, false, true}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, CompareLtU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<uint32>({0, 0, 0, 5, 5, 5, max, max, max});
+ auto rhs = builder.ConstantR1<uint32>({0, 1, max, 4, 5, 6, 0, 1, max});
+ auto compare = builder.Lt(lhs, rhs);
+
+ ComputeAndCompareR1<bool>(
+ &builder, {false, true, true, false, false, true, false, false, false},
+ {});
+}
+
+TEST_F(ArrayElementwiseOpTest, PowF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({4.0f, 2.0f, 2.0f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -2.0f, 3.0f, 10.0f, NAN});
+ auto minimum = builder.Pow(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder, {16.0f, 0.25f, 8.0f, NAN, NAN}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto minimum = builder.Pow(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// Some Pow cases that can be implemented more efficiently.
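+// The exponents used below (0, 1, 2, 0.5, -1 and -0.5) are ones a backend may
+// lower to cheaper operations than a general pow.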
+TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
+ std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
+
+ std::unique_ptr<Literal> param_literal = LiteralUtil::CreateR1<float>(values);
+ std::unique_ptr<GlobalData> param_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+
+ auto sum = b.ConstantR0<float>(0.0f);
+ auto param = b.Parameter(0, param_literal->shape(), "param");
+ for (float exponent : exponents) {
+ sum = b.Add(sum, b.Pow(param, b.ConstantR0<float>(exponent)));
+ }
+
+ std::vector<float> expected;
+ for (auto value : values) {
+ float sum = 0.0f;
+ for (float exponent : exponents) {
+ sum += std::pow(value, exponent);
+ }
+ expected.push_back(sum);
+ }
+
+ ComputeAndCompareR1<float>(&b, expected, {param_data.get()}, error_spec_);
+}
+
+TEST_P(ArrayElementwiseOpTestParamCount, SquareManyValues) {
+ const int count = GetParam();
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> values;
+ for (int i = 0; i < count; ++i) {
+ values.push_back(i / static_cast<float>(count));
+ }
+ auto x = builder.ConstantR1<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ std::vector<float> expected;
+ for (float value : values) {
+ expected.push_back(value * value);
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, SquareIn4D) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> values(2, 2, 2, 2);
+
+ std::vector<float> values_vector;
+ std::vector<float> expected_vector;
+ for (int i = 0; i < values.num_elements(); ++i) {
+ values_vector.push_back(static_cast<float>(i) / values.num_elements());
+ expected_vector.push_back(values_vector.back() * values_vector.back());
+ }
+ values.SetValues(values_vector);
+
+ Array4D<float> expected(2, 2, 2, 2, expected_vector);
+
+ auto x = builder.ConstantR4FromArray4D<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, SquareIn4DZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> values(2, 2, 0, 2);
+ Array4D<float> expected(2, 2, 0, 2);
+
+ auto x = builder.ConstantR4FromArray4D<float>(values);
+ auto exp = builder.Pow(x, builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+// The GPU backend emits nvvm intrinsics for fmin and fmax, whose semantics do
+// NOT guarantee that
+// * fmin(NaN, x) = x
+// * fmax(NaN, x) = x
+// so we only test NaN inputs on the CPU backend.
+//
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends.
+TEST_F(ArrayElementwiseOpTest, MinF32s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
+#else
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
+#endif
+ auto minimum = builder.Min(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {1.0f, -5.0f, 1.0f},
+#else
+ {1.0f, -5.0f, 1.0f, 10.0f, 6.0f},
+#endif
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MinZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto minimum = builder.Min(lhs, rhs);
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+XLA_TEST_F(ArrayElementwiseOpTest, MinF64s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
+#else
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
+#endif
+ auto minimum = builder.Min(lhs, rhs);
+
+ ComputeAndCompareR1<double>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {1.0, -5.0, 1.0},
+#else
+ {1.0, -5.0, 1.0, 10.0, 6.0},
+#endif
+ {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+TEST_F(ArrayElementwiseOpTest, MaxF32s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f});
+#else
+ auto lhs = builder.ConstantR1<float>({1.0f, 1.0f, 2.25f, NAN, 6.0f});
+ auto rhs = builder.ConstantR1<float>({2.0f, -5.0f, 1.0f, 10.0f, NAN});
+#endif
+ auto maximum = builder.Max(lhs, rhs);
+
+ ComputeAndCompareR1<float>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {2.0f, 1.0f, 2.25f},
+#else
+ {2.0f, 1.0f, 2.25f, 10.0f, 6.0f},
+#endif
+ {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto minimum = builder.Max(lhs, rhs);
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+// TODO(b/28180546): Make this compile in a way that is consistent
+// among backends. See comment on MinF32s test above.
+XLA_TEST_F(ArrayElementwiseOpTest, MaxF64s) {
+ ComputationBuilder builder(client_, TestName());
+#if !defined(XLA_TEST_BACKEND_CPU)
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0});
+#else
+ auto lhs = builder.ConstantR1<double>({1.0, 1.0, 2.25, NAN, 6.0});
+ auto rhs = builder.ConstantR1<double>({2.0, -5.0, 1.0, 10.0, NAN});
+#endif
+ auto maximum = builder.Max(lhs, rhs);
+
+ ComputeAndCompareR1<double>(&builder,
+#if !defined(XLA_TEST_BACKEND_CPU)
+ {2.0, 1.0, 2.25},
+#else
+ {2.0, 1.0, 2.25, 10.0, 6.0},
+#endif
+ {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>(
+ {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<int32>(
+ {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ builder.Max(x, y);
+
+ std::vector<int32> expected = {min, max, 0, -1, 0, 0, 0,
+ 1, 1, 10, max, max, max};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MinS32s) {
+ const int32 min = std::numeric_limits<int32>::min();
+ const int32 max = std::numeric_limits<int32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>(
+ {min, min, min, -1, -1, 0, 0, 0, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<int32>(
+ {min, max, 0, -10, 0, -1, 0, 1, 0, 10, 0, max, min});
+ builder.Min(x, y);
+
+ std::vector<int32> expected = {min, min, min, -10, -1, -1, 0,
+ 0, 0, 1, 0, max, min};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
+ builder.Max(x, y);
+
+ std::vector<uint32> expected = {0, 1, 1, 1, 10, max, max, max};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MinU32s) {
+ const uint32 max = std::numeric_limits<uint32>::max();
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>({0, 0, 1, 1, 1, max, max, max});
+ auto y = builder.ConstantR1<uint32>({0, 1, 0, 1, 10, 0, 234234, max});
+ builder.Min(x, y);
+
+ std::vector<uint32> expected = {0, 0, 0, 1, 1, 0, 234234, max};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, MaxTenF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = builder.ConstantR1<float>(
+ {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ builder.Max(x, y);
+
+ std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S1AndR1S0F32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto u = builder.ConstantR1<float>({3.5});
+ auto v = builder.ConstantR1<float>({});
+ builder.Max(u, v);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxR1S0AndR2S0x2F32s) {
+ for (int broadcast_dim : {0, 1}) {
+ ComputationBuilder builder(client_, TestName());
+ auto u = builder.ConstantR1<float>({3.5});
+ auto v = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ builder.Max(u, v, /*broadcast_dimensions=*/{broadcast_dim});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
+ }
+}
+
+TEST_F(ArrayElementwiseOpTest, Max1DAnd2DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
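+ // broadcast_dimensions={1} maps the R1 operand onto dimension 1 (the
+ // columns) of the R2 operand.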
+ builder.Max(v, m, /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({{2.0f, 3.14f, 4.0f}, {2.25f, 3.0f, 4.0f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max1DAnd2DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({});
+ auto m = builder.ConstantR2<float>({{}, {}});
+ builder.Max(v, m, /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({{}, {}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto scalar = builder.ConstantR0<int32>(2);
+ Array3D<int32> a_3d({{{3, 9, -1}, {2, -10, 3}}, {{-2, 2, 8}, {12, 10, 4}}});
+ auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
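+ // An empty broadcast_dimensions list broadcasts the scalar against every
+ // element of the array.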
+ builder.Max(array, scalar, /*broadcast_dimensions=*/{});
+
+ Array3D<int32> expected({{{3, 9, 2}, {2, 2, 3}}, {{2, 2, 8}, {12, 10, 4}}});
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Max3DAndScalarZeroElementS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto scalar = builder.ConstantR0<int32>(2);
+ Array3D<int32> a_3d(2, 0, 3);
+ auto array = builder.ConstantR3FromArray3D<int32>(a_3d);
+ builder.Max(array, scalar, /*broadcast_dimensions=*/{});
+
+ Array3D<int32> expected(2, 0, 3);
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, Min2DTo1DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto m =
+ builder.ConstantR2<float>({{-10.4f, 64.0f, 6.0f}, {0.1f, 32.0f, 16.1f}});
+ auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
+ builder.Min(m, v, /*broadcast_dimensions=*/{0});
+
+ Array2D<float> expected({{-10.4f, -10.2f, -10.2f}, {0.1f, 16.4f, 16.1f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo1DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantR2<float>({{}, {}});
+ auto v = builder.ConstantR1<float>({-10.2f, 16.4f});
+ builder.Min(m, v, /*broadcast_dimensions=*/{0});
+
+ Array2D<float> expected({{}, {}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto array2d =
+ builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ auto array4d = builder.ConstantR4FromArray4D<float>(
+ {{{{-12.1f, 32.3f, 6.2f}}, {{0.0f, 32.5f, 3.0f}}},
+ {{{-2.5f, 64.29f, 6.5f}}, {{-0.01f, 32.25f, 2.6f}}}});
+ builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+
+ Array4D<float> expected(
+ {{{{-12.2f, 32.3f, 6.1f}}, {{0.0f, 32.2f, 2.5f}}},
+ {{{-12.2f, 64.29f, 6.1f}}, {{-0.01f, 32.2f, 2.5f}}}});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Min2DTo4DZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto array2d =
+ builder.ConstantR2<float>({{-12.2f, 64.3f, 6.1f}, {0.0f, 32.2f, 2.5f}});
+ Array4D<float> arg(2, 2, 0, 3);
+ auto array4d = builder.ConstantR4FromArray4D<float>(arg);
+ builder.Min(array2d, array4d, /*broadcast_dimensions=*/{1, 3});
+
+ Array4D<float> expected(2, 2, 0, 3);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MinTenS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ builder.Min(x, y);
+
+ std::vector<int32> expected = {0, 1, 2, 3, 4, 4, 3, 2, 1, 0};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, MaxTenS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto y = builder.ConstantR1<int32>({9, 8, 7, 6, 5, 4, 3, 2, 1, 0});
+ builder.Max(x, y);
+
+ std::vector<int32> expected = {9, 8, 7, 6, 5, 5, 6, 7, 8, 9};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, RemTwoConstantS32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({-3, 26, 2, -1, 1});
+ auto b = builder.ConstantR1<int32>({10, 5, 1, 10, -10});
+ auto rem = builder.Rem(a, b);
+
+ ComputeAndCompareR1<int32>(&builder, {-3, 1, 0, -1, 1}, {});
+}
+
+TEST_F(ArrayElementwiseOpTest, NonNanClampF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto minimum = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 10.0f});
+ auto maximum = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
+ auto clamp = builder.Clamp(minimum, argument, maximum);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 0.5f, 1.0f, 2.25f, 10.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, ClampF32Scalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto minimum = builder.ConstantR0<float>(0.0f);
+ auto argument = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto maximum = builder.ConstantR0<float>(5.0f);
+ auto clamp = builder.Clamp(minimum, argument, maximum);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 5.0f, 0.0f, 1.0f, 4.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) {
+ ComputationBuilder builder(client_, TestName());
+ auto min_scalar = builder.ConstantR0<float>(0.0f);
+ auto min_vector = builder.ConstantR1<float>({1.0f, -6.5f, 1.0f, 2.25f, 0.0f});
+ auto arg_vector = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto arg_scalar = builder.ConstantR1<float>({2.0f, 10.0f, -5.0f, 1.0f, 4.0f});
+ auto max_scalar = builder.ConstantR0<float>(3.0f);
+ auto max_vector = builder.ConstantR1<float>({3.0f, 0.5f, 25.5f, 5.0f, 123.0f});
+ // Perform clamp with broadcasted scalar and vector.
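+ // The four clamps evaluate to {2, 3, 1, 2.25, 3}, {2, 0.5, 0, 1, 4},
+ // {2, 0.5, 1, 2.25, 4} and {2, 0.5, 0, 1, 4}; their sum is the expected
+ // result {8, 4.5, 2, 6.5, 15} checked below.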
+ auto clamp = builder.Add(
+ builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar),
+ builder.Clamp(min_scalar, arg_vector, max_vector)),
+ builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector),
+ builder.Clamp(min_scalar, arg_scalar, max_vector)));
+
+ ComputeAndCompareR1<float>(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({7.2f, 2.3f, 3.4f, 5.6f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ ComputeAndCompareR1<float>(&builder, {8.3f, 4.5f, 6.7f, 11.1f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(0, 7, 0));
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto p0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ Array3D<float> expected(0, 7, 0);
+ ComputeAndCompareR3<float>(
+ &builder, expected, {param0_data.get(), param1_data.get()}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
+ auto p = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto add = builder.Add(a, p);
+
+ ComputeAndCompareR1<float>(&builder, {2.2f, 4.4f, 6.6f, 9.9f},
+ {param0_data.get()}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, TanhF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({-2.5f, 3.14f, 2.25f});
+ auto result = builder.Tanh(a);
+
+ ComputeAndCompareR1<float>(&builder, {-0.986614f, 0.996260f, 0.978026f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
+ // a ------ (add) --------- (add)
+ // / /
+ // b -----/ /
+ // c---------------------/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({1.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+
+ auto add = builder.Add(a, b);
+ auto add2 = builder.Add(add, c);
+
+ ComputeAndCompareR1<float>(&builder, {-0.1f, -10.1f, -0.1f, -20.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainFoldRight) {
+ // b ------ (add) --------- (add)
+ // / /
+ // c -----/ /
+ // a---------------------/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+
+ auto add = builder.Add(b, c);
+ auto add2 = builder.Add(a, add);
+
+ ComputeAndCompareR1<float>(&builder, {89.9f, -10.1f, -0.1f, -20.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddWithNeg) {
+ // a ----- (neg) ----- (add)
+ // /
+ // b ----- (neg) ----/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+
+ auto neg_a = builder.Neg(a);
+ auto neg_b = builder.Neg(b);
+ auto result = builder.Add(neg_a, neg_b);
+
+ ComputeAndCompareR1<float>(&builder, {-93.2f, -5.4f, -7.6f, -9.8f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, AddChainTwoSide) {
+ // a ------ (add) ------------\
+ // / \
+ // b -----/ (add)
+ // /
+ // c ------ (add) ------------/
+ // /
+ // d -----/
+ ComputationBuilder builder(client_, TestName());
+
+ auto a = builder.ConstantR1<float>({91.1f, 2.2f, 3.3f, 4.4f});
+ auto b = builder.ConstantR1<float>({2.1f, 3.2f, 4.3f, 5.4f});
+ auto c = builder.ConstantR1<float>({-3.3f, -15.5f, -7.7f, -29.9f});
+ auto d = builder.ConstantR1<float>({-19.0f, 10.0f, -40.0f, 20.2f});
+
+ auto add_ab = builder.Add(a, b);
+ auto add_cd = builder.Add(c, d);
+ auto add_all = builder.Add(add_ab, add_cd);
+
+ ComputeAndCompareR1<float>(&builder, {70.9f, -0.1f, -40.1f, 0.1f}, {},
+ error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 2DBinaryOpF32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto b =
+ builder.ConstantR2<float>({{-1.5f, 8.14f, 42.0f}, {-1.0f, -4.0f, 5.55f}});
+ auto add = builder.Add(a, b);
+
+ Array2D<float> expected_array(
+ {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, ScalarPlus2DF32) {
+ // Add a scalar + matrix.
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = builder.ConstantR0<float>(3.0f);
+ auto add = builder.Add(scalar, a);
+
+ Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 2DPlusScalarF32) {
+ // Add a matrix + scalar.
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto scalar = builder.ConstantR0<float>(3.0f);
+ auto add = builder.Add(a, scalar);
+
+ Array2D<float> expected_array({{0.5f, 6.14f, 4.0f}, {5.25f, -7.0f, 6.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
+ // Test simple broadcasting of an R1F32 over R2F32. The vector's size matches
+ // only dim 1 of the matrix, which is the broadcast dimension used below.
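+ // For example, expected(0, 0) = v(0) + m(0, 0) = 20.0f + (-2.5f) = 17.5f.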
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f, 60.0f});
+ // clang-format off
+ auto m = builder.ConstantR2<float>({
+ {-2.5f, 3.14f, 1.0f},
+ {2.25f, -10.0f, 3.33f}});
+ // clang-format on
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array(
+ {{17.5f, 43.14f, 61.0f}, {22.25f, 30.0f, 63.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
+ // Test broadcasting in Eq comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({42, 73});
+ auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
+
+ // This test exercises both possible broadcast dimensions for a vector/matrix
+ // comparison.
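+ // With broadcast_dimensions={1} the vector is compared against each row
+ // (v(j) vs. m(i, j)); with {0} it is compared against each column
+ // (v(i) vs. m(i, j)), e.g. v(1)=73 vs. m(1, 0)=42 yields false.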
+ auto cmp_dim_0 = builder.Eq(v, m, /*broadcast_dimensions=*/{1});
+ auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0});
+ auto result = builder.Tuple({cmp_dim_0, cmp_dim_1});
+
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<bool>({{true, true}, {true, false}}).get(),
+ LiteralUtil::CreateR2<bool>({{true, false}, {false, false}}).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) {
+ // Test broadcasting in Ne comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({42, 73});
+ auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
+ auto cmp = builder.Ne(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,2] {
+ { 00 },
+ { 01 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) {
+ // Test broadcasting in Ge comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Ge(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 1100 },
+ { 0001 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) {
+ // Test broadcasting in Gt comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Gt(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 0100 },
+ { 0000 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) {
+ // Test broadcasting in Le comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Le(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 1011 },
+ { 1111 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) {
+ // Test broadcasting in Lt comparison.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, 2, 3, 4});
+ auto m = builder.ConstantR2<int32>({{1, 0, 5, 6}, {42, 52, 10, 4}});
+ auto cmp = builder.Lt(v, m, /*broadcast_dimensions=*/{1});
+
+ const string expected = R"(pred[2,4] {
+ { 0011 },
+ { 1110 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+TEST_F(ArrayElementwiseOpTest, Mul2Dby1DF32) {
+ // Test simple broadcasting of an R1F32 over R2F32 when the order of binary op
+ // arguments is reversed.
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantR2<float>({{1.5f, 2.5f, 3.5f}, {4.5f, 5.5f, 6.5f}});
+ auto v = builder.ConstantR1<float>({2.0f, 4.0f, 6.0f});
+ auto mul = builder.Mul(m, v, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array({{3.0f, 10.0f, 21.0f}, {9.0f, 22.0f, 39.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim1) {
+ // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
+ ComputationBuilder builder(client_, TestName());
+ // m's shape in XLA notation is {3, 2}
+ // md's shape in XLA notation is {3, 1}
+ // The result has shape {3, 2}, where md is broadcast over m
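+ // Each value of md's single row {10, 20, 30} is added down the
+ // corresponding column of m, e.g. 2.25f + 10.0f = 12.25f.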
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = builder.ConstantR2<float>({{10.0f, 20.0f, 30.0f}});
+ auto add = builder.Add(m, md);
+ Array2D<float> expected_array(
+ {{7.5f, 23.14f, 31.0f}, {12.25f, 10.0f, 33.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo2DWithDegenerateDim0) {
+ // Tests broadcasting for arrays with degenerate (size == 1) dimensions.
+ ComputationBuilder builder(client_, TestName());
+ // m's shape in XLA notation is {3, 2}
+ // md's shape in XLA notation is {1, 2}
+ // The result has shape {3, 2}, where md is broadcast over m
+ auto m =
+ builder.ConstantR2<float>({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}});
+ auto md = builder.ConstantR2<float>({{10.0f}, {20.0f}});
+ auto add = builder.Add(m, md);
+ Array2D<float> expected_array(
+ {{7.5f, 13.14f, 11.0f}, {22.25f, 10.0f, 23.33f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DsWithDegenerateDimsOuterProduct) {
+ // Tests broadcasting for two degenerate arrays. This kind of broadcasting
+ // effectively creates an "outer product" operation.
+ // This is taken from the Numpy docs example at:
+ // http://docs.scipy.org/doc/numpy-1.10.1/user/basics.broadcasting.html
+ ComputationBuilder builder(client_, TestName());
+ // a's shape in XLA notation is {1, 4}
+ // b's shape in XLA notation is {3, 1}
+ // The result has shape {3, 4}.
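+ // Every element of a is paired with every element of b, so
+ // expected(i, j) = a_i + b_j; the last entry is 30.0f + 3.0f = 33.0f.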
+ auto a = builder.ConstantR2<float>({{0.0f}, {10.0f}, {20.0f}, {30.0f}});
+ auto b = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto add = builder.Add(a, b);
+ Array2D<float> expected_array({{1.0f, 2.0f, 3.0f},
+ {11.0f, 12.0f, 13.0f},
+ {21.0f, 22.0f, 23.0f},
+ {31.0f, 32.0f, 33.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver1) {
+ // Add together a (2,2) array and a (2) array, using dimension 1 for
+ // broadcasting (though there are two ways to broadcast these shapes).
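+ // Both dimensions of m have size 2, so either could legally receive the
+ // vector; {1} maps v to dimension 1, so expected(i, j) = v(j) + m(i, j),
+ // e.g. 40.0f + 88.0f = 128.0f.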
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f});
+ auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{1});
+ Array2D<float> expected_array({{30.0f, 90.0f}, {97.0f, 128.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32TwoWaysOver0) {
+ // Add together a (2,2) array and a (2) array, using dimension 0 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({20.0f, 40.0f});
+ auto m = builder.ConstantR2<float>({{10.0f, 50.0f}, {77.0f, 88.0f}});
+ auto add = builder.Add(v, m, /*broadcast_dimensions=*/{0});
+ Array2D<float> expected_array({{30.0f, 70.0f}, {117.0f, 128.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, 3DBinaryOpF32s) {
+ // Binary add of two R3s together
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+
+ Array3D<float> b_3d({{{2.0f, 4.0f}, {6.0f, 8.0f}, {10.0f, 12.0f}},
+ {{14.0f, 16.0f}, {18.0f, 20.0f}, {22.0f, 24.0f}}});
+ auto b = builder.ConstantR3FromArray3D<float>(b_3d);
+ auto add = builder.Add(a, b);
+
+ Array3D<float> expected_3d(
+ {{{3.0f, 6.0f}, {9.0f, 12.0f}, {15.0f, 18.0f}},
+ {{21.0f, 24.0f}, {27.0f, 30.0f}, {33.0f, 36.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver2) {
+ // Add together a (2, 3, 2) array with a (2) array, using dimension 2 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ // clang-format on
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto v = builder.ConstantR1<float>({10.0f, 20.0f});
+ auto add = builder.Add(a, v, /*broadcast_dimensions=*/{2});
+
+ Array3D<float> expected_3d(
+ {{{11.0f, 22.0f}, {13.0f, 24.0f}, {15.0f, 26.0f}},
+ {{17.0f, 28.0f}, {19.0f, 30.0f}, {21.0f, 32.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo3DTwoWaysOver0) {
+ // Add together a (2, 3, 2) array with a (2) array, using dimension 0 for
+ // broadcasting (though there are two ways to broadcast these shapes).
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ // clang-format on
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto v = builder.ConstantR1<float>({10.0f, 20.0f});
+ auto add = builder.Add(a, v, /*broadcast_dimensions=*/{0});
+
+ // clang-format off
+ Array3D<float> expected_3d({
+ {{11.0f, 12.0f},
+ {13.0f, 14.0f},
+ {15.0f, 16.0f}},
+ {{27.0f, 28.0f},
+ {29.0f, 30.0f},
+ {31.0f, 32.0f}},
+ });
+ // clang-format on
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, Add2DTo3D) {
+ // Add together a (2, 3, 2) array with a (2, 3) array, using dimensions
+ // {0, 1} for broadcasting.
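+ // broadcast_dimensions={0, 1} maps the matrix onto the first two dimensions
+ // of a, so expected(p, z, x) = a(p, z, x) + m(p, z),
+ // e.g. 12.0f + 60.0f = 72.0f.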
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array3D<float> a_3d({
+ {{1.0f, 2.0f},
+ {3.0f, 4.0f},
+ {5.0f, 6.0f}},
+ {{7.0f, 8.0f},
+ {9.0f, 10.0f},
+ {11.0f, 12.0f}},
+ });
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+ auto m = builder.ConstantR2<float>({
+ {10.0f, 20.0f, 30.0f},
+ {40.0f, 50.0f, 60.0f},
+ });
+ auto add = builder.Add(a, m, /*broadcast_dimensions=*/{0, 1});
+
+ Array3D<float> expected_3d({
+ {{11.0f, 12.0f},
+ {23.0f, 24.0f},
+ {35.0f, 36.0f}},
+ {{47.0f, 48.0f},
+ {59.0f, 60.0f},
+ {71.0f, 72.0f}},
+ });
+ // clang-format on
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, error_spec_);
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, CompareGtR3F32sWithDegenerateDim0) {
+ // Comparison between two 3D arrays of compatible shapes:
+ // (2, 3, 2) and (1, 3, 2): expected to produce a (2, 3, 2) shape of PREDs.
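+ // Since b's first dimension has size 1, it is broadcast across a's first
+ // dimension, i.e. the comparison computed is a(p, z, x) > b(0, z, x).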
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_3d({{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}});
+ auto a = builder.ConstantR3FromArray3D<float>(a_3d);
+
+ Array3D<float> b_3d({{{7.0f, 1.0f}, {3.0f, 10.0f}, {15.0f, 6.0f}}});
+ auto b = builder.ConstantR3FromArray3D<float>(b_3d);
+
+ auto compare = builder.Gt(a, b);
+
+ Array3D<int> expected_3d(
+ {{{0, 1}, {0, 0}, {0, 0}}, {{0, 1}, {1, 0}, {0, 1}}});
+ const string expected = R"(pred[2,3,2] {
+{ { 01 },
+ { 00 },
+ { 00 } },
+{ { 01 },
+ { 10 },
+ { 01 } }
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+TEST_F(ArrayElementwiseOpTest, 4DBinaryOpF32s) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> operand_b_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
+ float value = 0.0;
+ for (int64 p = 0; p < 2; ++p) {
+ for (int64 z = 0; z < 3; ++z) {
+ for (int64 y = 0; y < 4; ++y) {
+ for (int64 x = 0; x < 5; ++x) {
+ (*operand_a_4d)(p, z, y, x) = value;
+ (*operand_b_4d)(p, z, y, x) = 2.0 * value;
+ (*expected_4d)(p, z, y, x) = 3.0 * value;
+ value += 0.1;
+ }
+ }
+ }
+ }
+
+ auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
+ auto b = builder.ConstantR4FromArray4D<float>(*operand_b_4d);
+ auto add = builder.Add(a, b);
+
+ ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, R4PlusR1InDim1) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Array4D<float>> operand_a_4d(new Array4D<float>(2, 3, 4, 5));
+ std::unique_ptr<Array4D<float>> expected_4d(new Array4D<float>(2, 3, 4, 5));
+ std::vector<float> operand_b_1d(3);
+ std::iota(operand_b_1d.begin(), operand_b_1d.end(), 1.0);
+
+ float value = 0.0;
+ for (int64 p = 0; p < 2; ++p) {
+ for (int64 z = 0; z < 3; ++z) {
+ for (int64 y = 0; y < 4; ++y) {
+ for (int64 x = 0; x < 5; ++x) {
+ (*operand_a_4d)(p, z, y, x) = value;
+ (*expected_4d)(p, z, y, x) = value + operand_b_1d[z];
+ value += 0.1;
+ }
+ }
+ }
+ }
+
+ auto a = builder.ConstantR4FromArray4D<float>(*operand_a_4d);
+ auto b = builder.ConstantR1<float>(operand_b_1d);
+ auto add = builder.Add(a, b, {1});
+
+ ComputeAndCompareR4<float>(&builder, *expected_4d, {}, error_spec_);
+}
+
+TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) {
+ constexpr int d0 = 16;
+ constexpr int d1 = 16;
+ constexpr int d2 = 2;
+ constexpr int d3 = 2;
+ Array4D<float> r4(d0, d1, d2, d3);
+ r4.Fill(1.0);
+ std::vector<float> r1(d1);
+ std::iota(r1.begin(), r1.end(), 1.0);
+
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR4FromArray4D(r4);
+ *a_literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout({0, 1, 2, 3});
+ auto a = builder.ConstantLiteral(*a_literal);
+ auto b = builder.ConstantR1<float>(r1);
+ builder.Add(a, b, {1});
+
+ for (int i0 = 0; i0 < d0; ++i0) {
+ for (int i1 = 0; i1 < d1; ++i1) {
+ for (int i2 = 0; i2 < d2; ++i2) {
+ for (int i3 = 0; i3 < d3; ++i3) {
+ r4(i0, i1, i2, i3) += r1[i1];
+ }
+ }
+ }
+ }
+ ComputeAndCompareR4<float>(&builder, r4, {}, error_spec_);
+}
+
+// Show that we can't add two opaques.
+TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) {
+ ComputationBuilder builder(client_, TestName());
+ auto shape = ShapeUtil::MakeOpaqueShape();
+ auto x = builder.Parameter(0, shape, "x");
+ auto add = builder.Add(x, x);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "Expected non-opaque argument for lhs of binary operation"));
+}
+
+// Regression test for b/31927799. "slice - y" is fused and requires implicit
+// broadcast.
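+// The slice yields the single-element vector {2}; subtracting the
+// two-element y implicitly broadcasts that size-1 operand inside the fused
+// expression, giving {2 - 4, 2 - 5} = {-2, -3}.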
+TEST_F(ArrayElementwiseOpTest, ImplicitBroadcastInFusedExpressions) {
+ ComputationBuilder builder(client_, TestName());
+ auto x_literal = LiteralUtil::CreateR1<float>({1, 2, 3});
+ auto y_literal = LiteralUtil::CreateR1<float>({4, 5});
+ auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+
+ auto x = builder.Parameter(0, x_literal->shape(), "x");
+ auto y = builder.Parameter(1, y_literal->shape(), "y");
+ auto slice = builder.Slice(x, {1}, {2});
+ builder.Sub(slice, y);
+
+ ComputeAndCompareR1<float>(&builder, {-2, -3}, {x_data.get(), y_data.get()},
+ error_spec_);
+}
+
+INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
+ ArrayElementwiseOpTestParamCount,
+ ::testing::Values(127, 128, 129, 17 * 4096));
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::legacy_flags::AppendLlvmBackendFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
new file mode 100644
index 0000000000..adffac09e3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class AxpySimpleTest : public ClientLibraryTestBase {};
+
+TEST_F(AxpySimpleTest, AxTenValues) {
+ ComputationBuilder builder(client_, "ax_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Mul(alpha, x);
+
+ std::vector<float> expected = {
+ -3.14159265, 3.14159265, 6.28318531, -6.28318531, -9.42477796,
+ 9.42477796, 12.56637061, -12.56637061, -15.70796327, 15.70796327};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
+ ComputationBuilder builder(client_, "axpy_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>({});
+ auto y = builder.ConstantR1<float>({});
+ auto ax = builder.Mul(alpha, x);
+ auto axpy = builder.Add(ax, y);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(AxpySimpleTest, AxpyTenValues) {
+ ComputationBuilder builder(client_, "axpy_10");
+ auto alpha = builder.ConstantR0<float>(3.1415926535);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto y = builder.ConstantR1<float>(
+ {5.0, -5.0, -4.0, 4.0, 3.0, -3.0, -2.0, 2.0, 1.0, -1.0});
+ auto ax = builder.Mul(alpha, x);
+ auto axpy = builder.Add(ax, y);
+
+ std::vector<float> expected = {
+ 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796,
+ 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
new file mode 100644
index 0000000000..c7b533b80f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that passing a bad shape to RNG's output parameter causes a validation
+// failure rather than causing a crash.
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BadRngShapeValidationTest : public ClientLibraryTestBase {};
+
+TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto one = builder.ConstantR0<float>(1.0);
+ Shape default_constructed;
+ builder.RngUniform(zero, one, default_constructed);
+
+ StatusOr<Computation> computation = builder.Build();
+ EXPECT_FALSE(computation.ok());
+ LOG(INFO) << "status received: " << computation.status();
+ EXPECT_MATCH(computation.status().error_message(),
+ testing::HasSubstr("shape has invalid"));
+}
+
+TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto one = builder.ConstantR0<float>(1.0);
+ Shape sans_layout;
+ sans_layout.set_element_type(F32);
+ sans_layout.add_dimensions(1);
+
+ builder.RngUniform(zero, one, sans_layout);
+
+ StatusOr<Computation> computation = builder.Build();
+ ASSERT_TRUE(computation.ok());
+ LOG(INFO) << computation.status();
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
new file mode 100644
index 0000000000..598fd69909
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class BatchNormalizationTest : public ClientLibraryTestBase {
+ protected:
+ BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
+ Array2D<float> pz({
+ // z0 z1
+ {-1.0f, 4.1f}, // p0
+ {2.0f, 4.1f}, // p1
+ {5.0f, 4.4f}, // p2
+ });
+ input_array_.FillWithPZ(pz);
+ input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_);
+ CHECK_EQ(kSamples, input_array_.planes());
+ CHECK_EQ(kZ, input_array_.depth());
+ CHECK_EQ(kY, input_array_.height());
+ CHECK_EQ(kX, input_array_.width());
+ }
+
+ static constexpr int64 kSamples = 3;
+ static constexpr int64 kX = 1;
+ static constexpr int64 kY = 1;
+ static constexpr int64 kZ = 2;
+
+ Array4D<float> input_array_;
+ Literal input_literal_;
+ const ErrorSpec error_spec_{0.001, 0.001};
+};
+
+TEST_F(BatchNormalizationTest, SubtractInZ) {
+ ComputationBuilder builder(client_, "subtract_in_z_one_sample");
+ auto x = builder.ConstantLiteral(input_literal_);
+ auto y = builder.ConstantR1<float>({3.14, 4.25});
+ builder.Sub(x, y, /*broadcast_dimensions=*/{1});
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> pz({
+ {-1.0f - 3.14f, 4.1f - 4.25f}, // p0
+ {2.0f - 3.14f, 4.1f - 4.25f}, // p1
+ {5.0f - 3.14f, 4.4f - 4.25f}, // p2
+ });
+ expected.FillWithPZ(pz);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SquareTesseractElementwise) {
+ ComputationBuilder builder(client_, "square_tesseract_elementwise");
+ auto x = builder.ConstantLiteral(input_literal_);
+ builder.SquareF32(x);
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> expected_pz({
+ {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)},
+ {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)},
+ {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)},
+ });
+ expected.FillWithPZ(expected_pz);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SumToZ) {
+ ComputationBuilder builder(client_, "sum_to_z");
+ auto input_activations = builder.ConstantLiteral(input_literal_);
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ // Reduce all but the Z dimension.
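+ // Reducing over dimensions {0, 2, 3} leaves one sum per z value:
+ // z0: -1 + 2 + 5 = 6, z1: 4.1 + 4.1 + 4.4 = 12.6.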
+ builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+ {0, 2, 3});
+
+ std::vector<float> expected = {6, 12.6};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, SquareAndReduce) {
+ ComputationBuilder builder(client_, "square_and_reduce");
+ auto input_activations = builder.ConstantLiteral(input_literal_);
+ auto set_means = builder.ConstantR1<float>({2.f, 4.2f});
+ auto activation_deviations = builder.Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ auto dev_squares = builder.SquareF32(activation_deviations);
+ auto sum_of_squares = builder.Reduce(
+ dev_squares, builder.ConstantR0<float>(0.0f), add, {0, 2, 3});
+
+ std::vector<float> expected = {18, 0.06};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(BatchNormalizationTest, VarianceToStddev) {
+ ComputationBuilder builder(client_, "variance_to_stddev");
+ auto variance = builder.ConstantR1<float>({6.f, .02f});
+ auto sqrt = builder.SqrtF32(variance);
+
+ std::vector<float> expected = {2.44948974f, 0.14142136f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+// Compare against a forward batch normalization example in the NN spec
+// reference.
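+// The graph built below computes, per z: mean = sum(x) / count,
+// variance = sum((x - mean)^2) / count, and then
+// out = gamma * (x - mean) / max(sqrt(variance), epsilon) + beta.
+// For this input the means are {2, 4.2} and the variances are {6, 0.02}.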
+TEST_F(BatchNormalizationTest, SpecComparisonForward) {
+ ComputationBuilder builder(client_, "batch_normalize_per_spec");
+ auto input_activations =
+ builder.CheckShape(builder.ConstantLiteral(input_literal_),
+ ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
+ auto gamma = builder.ConstantR1<float>({1.0, 1.0});
+ auto beta = builder.ConstantR1<float>({0.0, 0.0});
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ // Reduce all dimensions except dimension 1.
+ Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
+ auto sum = builder.CheckShape(
+ builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0, 2, 3}),
+ TwoElementVectorF32);
+ auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
+ auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
+ auto count = builder.ConstantR0<float>(ShapeUtil::ElementsIn(*input_shape) /
+ ShapeUtil::ElementsIn(*sum_shape));
+ auto set_means = builder.Div(sum, count);
+
+ const float kEpsilon = 1e-9f;
+ auto epsilon = builder.ConstantR0<float>(kEpsilon);
+ auto epsilon2 = builder.ConstantR1<float>({kEpsilon, kEpsilon});
+ auto activation_deviations = builder.Sub(input_activations, set_means,
+ /*broadcast_dimensions=*/{1});
+ auto dev_squares = builder.SquareF32(activation_deviations);
+ auto sum_of_squares = builder.CheckShape(
+ builder.Reduce(dev_squares, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0, 2, 3}),
+ TwoElementVectorF32);
+ auto variance = builder.Div(sum_of_squares, count);
+ auto standard_deviation = builder.SqrtF32(variance);
+ auto standard_deviation_above_epsilon = builder.CheckShape(
+ builder.Gt(standard_deviation, epsilon), ShapeUtil::MakeShape(PRED, {2}));
+ auto gt_eps = builder.Select(standard_deviation_above_epsilon,
+ standard_deviation, epsilon2);
+ auto normalization_factors = builder.ReciprocalF32(gt_eps);
+ auto normalized_input_activations =
+ builder.Mul(activation_deviations, normalization_factors,
+ /*broadcast_dimensions=*/{1});
+ /* auto output_activations = */ builder.Add(
+ builder.Mul(normalized_input_activations, gamma,
+ /*broadcast_dimensions=*/{1}),
+ beta, /*broadcast_dimensions=*/{1});
+
+ Array4D<float> expected(kSamples, kZ, kY, kX);
+ Array2D<float> pz({
+ {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
+ {0.f, -.1f / std::sqrt(.02f)},
+ {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
+ });
+ expected.FillWithPZ(pz);
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
new file mode 100644
index 0000000000..e825bd435b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
@@ -0,0 +1,157 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BinopScalingTest : public ClientLibraryTestBase {};
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(0, col);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(0, col);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 9, 5);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(row, 0);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
+ auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257);
+ auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1);
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ builder.Add(lhs, rhs);
+
+ auto aexpected = ReferenceUtil::MapWithIndexArray2D(
+ *alhs, [&](float lhs_value, int64 row, int64 col) {
+ return lhs_value + (*arhs)(row, 0);
+ });
+ ComputeAndCompareR2<float>(&builder, *aexpected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, R0PlusR2F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR0<float>(42.0);
+ auto rhs = builder.ConstantR2<float>({
+ {1.0, 2.0}, {3.0, 4.0},
+ });
+ builder.Add(lhs, rhs);
+
+ Array2D<float> expected(2, 2);
+ expected(0, 0) = 42.0 + 1.0;
+ expected(0, 1) = 42.0 + 2.0;
+ expected(1, 0) = 42.0 + 3.0;
+ expected(1, 1) = 42.0 + 4.0;
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(BinopScalingTest, R4PlusR0S32) {
+ ComputationBuilder builder(client_, TestName());
+ // clang-format off
+ Array4D<int> lhs_array({
+ {{{1, 2},
+ {3, 4},
+ {5, 6}}},
+ {{{7, 8},
+ {9, 10},
+ {11, 12}}},
+ });
+ Array4D<int> expected({
+ {{{43, 44},
+ {45, 46},
+ {47, 48}}},
+ {{{49, 50},
+ {51, 52},
+ {53, 54}}},
+ });
+ // clang-format on
+
+ auto lhs = builder.ConstantR4FromArray4D(lhs_array);
+ auto rhs = builder.ConstantR0<int>(42);
+ builder.Add(lhs, rhs);
+ ComputeAndCompareR4<int>(&builder, expected, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
new file mode 100644
index 0000000000..200d4d4563
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using BroadcastSimpleTest = ClientLibraryTestBase;
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(1.5), {});
+ ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
+ Array2D<float> expected(2, 3, 2.25);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
+ Array2D<float> expected(2, 0);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
+ Array2D<float> expected(0, 2);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
+
+ Array2D<float> expected(2, 3);
+ expected(0, 0) = 1;
+ expected(0, 1) = 2;
+ expected(0, 2) = 3;
+ expected(1, 0) = 1;
+ expected(1, 1) = 2;
+ expected(1, 2) = 3;
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({}), {2});
+
+ Array2D<float> expected(2, 0);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
+ ComputationBuilder b(client_, TestName());
+ b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
+
+ Array2D<float> expected(0, 3);
+ ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
+ // Verify that binary op and degenerate dimension broadcast work together in
+ // the same operation.
+ //
+ // The lhs shape [1, 2] is first broadcast up to [2, 1, 2] using in-dimension
+ // broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
+ // [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
+ // dimensions.
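+ // For example, expected(0, 1, 1) = 5.0 (the second lhs value) + 3.0
+ // (the rhs value at that position) = 8.0.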
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
+ b.ConstantLiteral(*LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
+
+ auto expected =
+ LiteralUtil::CreateR3<float>({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}},
+ {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}});
+
+ ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
+ // Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
+ // results in a shape incompatible with the rhs [2, 3, 1].
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
+ b.ConstantLiteral(*LiteralUtil::CreateR3<float>(
+ {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})),
+ /*broadcast_dimensions=*/{1, 2});
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(result_status.status().error_message(),
+ testing::ContainsRegex("broadcast dimension 0 mismatch"));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
+ // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
+ b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes"));
+}
+
+XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
+ // Test invalid broadcasting with [1, 2] and [2, 3] inputs.
+ ComputationBuilder b(client_, TestName());
+
+ b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
+ b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
+
+ auto result_status = Execute(&b, {});
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("binary op BINOP_ADD with incompatible shapes"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc
new file mode 100644
index 0000000000..1796a732e5
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/broadcast_test.cc
@@ -0,0 +1,286 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class BroadcastTest : public HloTestBase {};
+
+XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
+ // Test degenerate case of broadcasting a scalar into a scalar.
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(42.0), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+
+  // Broadcast the vector along dimension 0 and along dimension 1, then join
+  // the two results in a tuple so both can be checked.
+ auto element1 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 2}), input, {0}));
+ auto element2 = builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 3}), input, {1}));
+ builder.AddInstruction(HloInstruction::CreateTuple({element1, element2}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}),
+ result->tuple_literals(0), error_spec_);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}),
+ result->tuple_literals(1), error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
+  // Broadcasting a shape into a shape of the same rank with permuted
+  // dimensions reorders the elements, i.e., performs a transpose.
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}}), *result,
+ error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}),
+ *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));
+
+ // Broadcast vector in dimension 1.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(2, 2, 3, 3);
+ Array2D<float> pz({{1, 2}, {1, 2}});
+ expected.FillWithPZ(pz);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
+ auto builder = HloComputation::Builder(TestName());
+ std::vector<float> input_data(1025);
+ int64 r1_size = input_data.size();
+ std::iota(input_data.begin(), input_data.end(), 0.0f);
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));
+
+ // Broadcast vector in dimension 3.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(3, 3, 3, 1025);
+ Array2D<float> yx(/*height=*/3, /*width=*/r1_size);
+ for (int64 y = 0; y < 3; ++y) {
+ for (int64 x = 0; x < r1_size; ++x) {
+ yx(y, x) = input_data[x];
+ }
+ }
+ expected.FillWithYX(yx);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
+ auto builder = HloComputation::Builder(TestName());
+ Array4D<float> r4_array(32, 64, 7, 7);
+ r4_array.Fill(42.0);
+ std::vector<float> r1_array(64, 42.0);
+
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));
+
+ // Broadcast vector in dimension 1.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array),
+ *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto input = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ LOG(INFO) << hlo_module->ToString();
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(64, 64, 3, 3);
+ expected.Fill(1.0f);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
+ auto builder = HloComputation::Builder(TestName());
+ Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ auto input = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));
+
+  // Broadcast the 2x2 matrix into dimensions 2 and 3.
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));
+
+ // Create HLO module, compile, and execute.
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ hlo_module->AddEntryComputation(builder.Build());
+ auto result = ExecuteAndTransfer(std::move(hlo_module), {});
+
+ Array4D<float> expected(3, 3, 2, 2);
+ expected.FillWithYX(to_broadcast);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR4FromArray4D<float>(expected), *result, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
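As a reading aid for the HLO-level broadcast tests above (an illustrative sketch only, not part of this change): the dimensions argument of HloInstruction::CreateBroadcast maps operand dimension i to output dimension dimensions[i], and every output dimension not mentioned is filled by repetition. The rank-1 case from BroadcastVectorTo2D then reads as follows; the builder name is arbitrary:

    // Sketch: with dimensions {0}, the [3] vector {1, 2, 3} is laid along
    // dimension 0 of the [3, 2] output, so every column reads 1, 2, 3; with
    // dimensions {1} into a [2, 3] output it is laid along dimension 1, so
    // every row reads 1, 2, 3.
    auto builder = HloComputation::Builder("broadcast_sketch");
    auto vec = builder.AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
    builder.AddInstruction(HloInstruction::CreateBroadcast(
        ShapeUtil::MakeShape(F32, {3, 2}), vec, /*dimensions=*/{0}));
    builder.AddInstruction(HloInstruction::CreateBroadcast(
        ShapeUtil::MakeShape(F32, {2, 3}), vec, /*dimensions=*/{1}));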
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
new file mode 100644
index 0000000000..2c7eeb820d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -0,0 +1,149 @@
+"""Build rules for XLA testing."""
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
+
+def all_backends():
+ if cuda_is_configured():
+ return ["cpu", "cpu_parallel", "gpu"]
+ else:
+ return ["cpu", "cpu_parallel"]
+
+def xla_test(name,
+ srcs,
+ deps,
+ backends=[],
+ args=[],
+ tags=[],
+ copts=[],
+ backend_tags={},
+ backend_args={},
+ **kwargs):
+ """Generates cc_test targets for the given XLA backends.
+
+ This rule generates a cc_test target for one or more XLA backends and also
+ a platform-agnostic cc_library rule. The arguments are identical to cc_test
+  with three additions: 'backends', 'backend_args' and 'backend_tags'.
+  'backends' specifies the backends to generate tests for ("cpu",
+  "cpu_parallel", "gpu"), and 'backend_args'/'backend_tags' specify
+  backend-specific args and tags to use when generating each cc_test.
+
+  The name of each generated cc_test is the provided name argument with the
+  backend name appended, and the cc_library target name is the provided name
+  argument with "_lib" appended. For example, if the name parameter is
+  "foo_test", then the cpu test target will be "foo_test_cpu" and the
+  cc_library target is "foo_test_lib".
+
+ The cc_library target can be used to link with other plugins outside of
+ xla_test.
+
+ The build rule also defines a test suite ${name} which includes the tests for
+ each of the supported backends.
+
+ Each generated cc_test target has a tag indicating which backend the test is
+  for. This tag is of the form "xla_${BACKEND}" (e.g., "xla_cpu"). These
+ tags can be used to gather tests for a particular backend into a test_suite.
+
+ Examples:
+
+ # Generates the targets: foo_test_cpu and foo_test_gpu.
+ xla_test(
+ name = "foo_test",
+ srcs = ["foo_test.cc"],
+ backends = ["cpu", "gpu"],
+ deps = [...],
+ )
+
+ # Generates the targets: bar_test_cpu and bar_test_gpu. bar_test_cpu
+ # includes the additional arg "--special_cpu_flag".
+ xla_test(
+ name = "bar_test",
+ srcs = ["bar_test.cc"],
+ backends = ["cpu", "gpu"],
+      backend_args = {"cpu": ["--special_cpu_flag"]},
+ deps = [...],
+ )
+
+ The build rule defines the preprocessor macro XLA_TEST_BACKEND_${BACKEND}
+  to the value 1, where ${BACKEND} is the uppercase name of the backend.
+
+ Args:
+ name: Name of the target.
+ srcs: Sources for the target.
+ deps: Dependencies of the target.
+ backends: A list of backends to generate tests for. Supported
+ values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
+ be generated for all supported backends.
+ args: Test arguments for the target.
+ tags: Tags for the target.
+ backend_args: A dict mapping backend name to list of additional args to
+ use for that target.
+ backend_tags: A dict mapping backend name to list of additional tags to
+ use for that target.
+ """
+ test_names = []
+ if not backends:
+ backends = all_backends()
+
+ native.cc_library(
+ name="%s_lib" % name,
+ srcs=srcs,
+ copts=copts,
+ testonly=True,
+ deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"],
+ )
+
+ for backend in backends:
+ test_name = "%s_%s" % (name, backend)
+ this_backend_tags = ["xla_%s" % backend]
+ this_backend_copts = []
+ this_backend_args = backend_args.get(backend, [])
+ if backend == "cpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ elif backend == "cpu_parallel":
+ backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
+ this_backend_args += ["--xla_cpu_parallel=true"]
+ elif backend == "gpu":
+ backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
+ backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
+ this_backend_tags += ["requires-gpu-sm35"]
+ else:
+ fail("Unknown backend %s" % backend)
+
+ native.cc_test(
+ name=test_name,
+ srcs=srcs,
+ tags=tags + backend_tags.get(backend, []) + this_backend_tags,
+ copts=copts + ["-DXLA_TEST_BACKEND_%s=1" % backend.upper()] +
+ this_backend_copts,
+ args=args + this_backend_args,
+ deps=deps + backend_deps,
+ **kwargs)
+
+ test_names.append(test_name)
+
+ native.test_suite(name=name, tests=test_names)
+
+
+def generate_backend_suites(backends=[]):
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.test_suite(name="%s_tests" % backend,
+ tags = ["xla_%s" % backend])
+
+
+def generate_backend_test_macros(backends=[]):
+ if not backends:
+ backends = all_backends()
+ for backend in backends:
+ native.cc_library(
+ name="test_macros_%s" % backend,
+ testonly = True,
+ hdrs = ["test_macros.h"],
+ copts = ["-DXLA_PLATFORM=\\\"%s\\\"" % backend.upper()],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ])
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
new file mode 100644
index 0000000000..1c96b73034
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CallOpTest : public ClientLibraryTestBase {
+ protected:
+ Computation CreateR0F32IdentityComputation() {
+ ComputationBuilder builder(client_, "Identity");
+ builder.Parameter(0, r0f32_, "x");
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR1S0F32AdditionComputation() {
+ ComputationBuilder builder(client_, "Addition");
+ auto x = builder.Parameter(0, r1s0f32_, "x");
+ auto y = builder.Parameter(1, r1s0f32_, "y");
+ builder.Add(x, y);
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateR1S2F32AdditionComputation() {
+ ComputationBuilder builder(client_, "Addition");
+ auto x = builder.Parameter(0, r1s2f32_, "x");
+ auto y = builder.Parameter(1, r1s2f32_, "y");
+ builder.Add(x, y);
+ auto build_status = builder.Build();
+ EXPECT_IS_OK(build_status.status());
+ return build_status.ConsumeValueOrDie();
+ }
+
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r1s0f32_ = ShapeUtil::MakeShape(F32, {0});
+ Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
+};
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR0F32IdentityComputation();
+ auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0<float>(42.0));
+ builder.Call(callee, {constant});
+
+ ComputeAndCompareR0<float>(&builder, 42.0, {}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR1S0F32AdditionComputation();
+ auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({}));
+ auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({}));
+ builder.Call(callee, {x, y});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) {
+ ComputationBuilder builder(client_, TestName());
+ Computation callee = CreateR1S2F32AdditionComputation();
+ auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({1.0f, 2.0f}));
+ auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({2.0f, 3.0f}));
+ builder.Call(callee, {x, y});
+
+ ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
new file mode 100644
index 0000000000..675c9fccb0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CheckExecutionArityTest : public ClientLibraryTestBase {};
+
+TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) {
+ ComputationBuilder builder(client_, "add_two_params");
+ auto param_literal = LiteralUtil::CreateR1<float>({1.1f, 2.2f});
+
+ auto p0 = builder.Parameter(0, param_literal->shape(), "param0");
+ auto p1 = builder.Parameter(1, param_literal->shape(), "param1");
+ auto add = builder.Add(p0, p1);
+
+ auto param0_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+ auto param1_data =
+ client_->TransferToServer(*param_literal).ConsumeValueOrDie();
+
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+
+ // The arity of the UserComputation is 2 arguments. Execution will succeed
+ // with 2 arguments, but fail with a different number.
+ auto result_two_args =
+ client_->Execute(computation, {param0_data.get(), param1_data.get()});
+ ASSERT_IS_OK(result_two_args.status());
+
+ auto result_one_arg = client_->Execute(computation, {param0_data.get()});
+ ASSERT_FALSE(result_one_arg.ok());
+ ASSERT_EQ(result_one_arg.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(result_one_arg.status().error_message(),
+ testing::ContainsRegex("takes 2"));
+
+ auto result_zero_args = client_->Execute(computation, {});
+ ASSERT_FALSE(result_zero_args.ok());
+ ASSERT_EQ(result_zero_args.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(result_zero_args.status().error_message(),
+ testing::ContainsRegex("takes 2"));
+}
+
+XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) {
+ ComputationBuilder builder(client_, "add_two_params");
+
+ auto p0 = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto p1 = builder.Parameter(1, ShapeUtil::MakeShape(F32, {4}), "param1");
+ auto add = builder.Mul(p0, p1);
+
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto computation = computation_status.ConsumeValueOrDie();
+
+ auto f32_literal = LiteralUtil::CreateR0<float>(1.1f);
+ auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie();
+ auto f32_4_literal = LiteralUtil::CreateR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto f32_4_data =
+ client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie();
+ auto u8_4_literal = LiteralUtil::CreateR1U8("hola");
+ auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie();
+
+ // Match
+ auto status =
+ client_->Execute(computation, {f32_data.get(), f32_4_data.get()});
+ ASSERT_IS_OK(status.status());
+
+ // Shape mismatch in parameter 0
+ status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 0"));
+
+ // Shape mismatch in parameter 1 (rank)
+ status = client_->Execute(computation, {f32_data.get(), f32_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 1"));
+
+ // Shape mismatch in parameter 1 (element type)
+ status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()});
+ ASSERT_FALSE(status.ok());
+ ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT);
+ ASSERT_MATCH(status.status().error_message(),
+ testing::ContainsRegex("expects parameter 1"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
new file mode 100644
index 0000000000..d2a7def5d0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -0,0 +1,263 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_pass_pipeline_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+// Wrapper function that creates a nicer error message (than a bare
+// ValueOrDie()) if the platform we intend to test is not available.
+Client* GetOrCreateLocalClientOrDie(se::Platform* platform) {
+ StatusOr<Client*> result = ClientLibrary::GetOrCreateLocalClient(platform);
+ TF_CHECK_OK(result.status()) << "could not create local client for testing";
+ return result.ValueOrDie();
+}
+} // namespace
+
+ClientLibraryTestBase::ClientLibraryTestBase(
+ se::Platform* platform,
+ tensorflow::gtl::ArraySlice<string> disabled_pass_names)
+ : client_(GetOrCreateLocalClientOrDie(platform)) {
+ legacy_flags::HloPassPipelineFlags* flags =
+ legacy_flags::GetHloPassPipelineFlags();
+ flags->xla_disable_hlo_passes =
+ tensorflow::str_util::Join(disabled_pass_names, ",");
+}
+
+string ClientLibraryTestBase::TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+}
+
+StatusOr<std::unique_ptr<GlobalData>> ClientLibraryTestBase::Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ // Build the computation, as a convenience.
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ return client_->Execute(computation, arguments);
+}
+
+StatusOr<std::unique_ptr<Literal>> ClientLibraryTestBase::ExecuteAndTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout) {
+ // Build the computation, as a convenience.
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build());
+ return client_->ExecuteAndTransfer(computation, arguments,
+ shape_with_output_layout);
+}
+
+std::unique_ptr<GlobalData> ClientLibraryTestBase::ExecuteOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ return Execute(builder, arguments).ConsumeValueOrDie();
+}
+
+std::unique_ptr<Literal> ClientLibraryTestBase::ExecuteAndTransferOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ return ExecuteAndTransfer(builder, arguments).ConsumeValueOrDie();
+}
+
+string ClientLibraryTestBase::ExecuteToString(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ StatusOr<Computation> computation_status = builder->Build();
+ if (!computation_status.ok()) {
+ return computation_status.status().ToString();
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ auto result = client_->ExecuteAndTransfer(computation, arguments);
+ if (!result.ok()) {
+ return result.status().ToString();
+ } else {
+ return LiteralUtil::ToString(*result.ValueOrDie());
+ }
+}
+
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout) {
+ EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
+ shape_with_layout));
+}
+
+void ClientLibraryTestBase::ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout) {
+ EXPECT_IS_OK(ComputeAndCompareLiteralWithStatus(builder, expected, arguments,
+ error, shape_with_layout));
+}
+
+tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout) {
+ TF_ASSIGN_OR_RETURN(
+ auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout));
+ if (ShapeUtil::ElementIsFloating(expected.shape())) {
+ LOG(WARNING) << "performing exact comparison of floating point numbers";
+ } else {
+ TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) ||
+ expected.shape().element_type() == PRED);
+ }
+ LiteralTestUtil::ExpectEqual(expected, *actual);
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout) {
+ TF_ASSIGN_OR_RETURN(
+ auto actual, ExecuteAndTransfer(builder, arguments, shape_with_layout));
+ TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()));
+ LiteralTestUtil::ExpectNear(expected, *actual, error);
+ return tensorflow::Status::OK();
+}
+
+void ClientLibraryTestBase::ComputeAndCompareR1U8(
+ ComputationBuilder* builder, tensorflow::StringPiece expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+
+ // Turn the expected value into a literal.
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR1U8(expected);
+
+ VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(*actual);
+
+ EXPECT_EQ(expected, actual->u8s());
+}
+
+void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectEqualTuple(expected, *actual);
+}
+
+void ClientLibraryTestBase::ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ auto actual_status = ExecuteAndTransfer(builder, arguments);
+ EXPECT_IS_OK(actual_status.status());
+ if (!actual_status.ok()) {
+ return;
+ }
+ auto actual = actual_status.ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNearTuple(expected, *actual, error);
+}
+
+Computation ClientLibraryTestBase::CreateScalarRelu() {
+ ComputationBuilder builder(client_, "relu");
+ auto z_value = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Max(z_value, zero);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+Computation ClientLibraryTestBase::CreateScalarMax() {
+ ComputationBuilder builder(client_, "max");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Max(x, y);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+Computation ClientLibraryTestBase::CreateScalarReluSensitivity() {
+ ComputationBuilder builder(client_, "relu_sensitivity");
+ auto activation =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "activation");
+ auto backprop =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "backprop");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto activation_gtz = builder.Gt(activation, zero);
+ builder.Select(activation_gtz, /*on_true=*/backprop, /*on_false=*/zero);
+
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+}
+
+std::unique_ptr<Array2D<float>> ClientLibraryTestBase::CreatePatternedMatrix(
+ int rows, int cols, float offset) {
+ auto array = MakeUnique<Array2D<float>>(rows, cols);
+ for (int64 row = 0; row < rows; ++row) {
+ for (int64 col = 0; col < cols; ++col) {
+ (*array)(row, col) = col + (row * 1000.0f) + offset;
+ }
+ }
+ return array;
+}
+
+std::unique_ptr<Array2D<float>>
+ClientLibraryTestBase::CreatePatternedMatrixWithZeroPadding(int rows, int cols,
+ int rows_padded,
+ int cols_padded) {
+ CHECK_GE(rows_padded, rows);
+ CHECK_GE(cols_padded, cols);
+ auto array = MakeUnique<Array2D<float>>(rows_padded, cols_padded, 0.0);
+ for (int64 row = 0; row < rows; ++row) {
+ for (int64 col = 0; col < cols; ++col) {
+ (*array)(row, col) = col + (row * 1000.0f);
+ }
+ }
+ return array;
+}
+
+} // namespace xla
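A usage note on the comparison helpers defined above: ComputeAndCompareLiteralWithStatus without an ErrorSpec performs an exact comparison (it only logs a warning when the expected literal is floating point), so float-valued tests normally go through the ErrorSpec overloads. A minimal sketch of a derived test, assuming a hypothetical fixture named ExampleVectorTest (illustration only, not part of this change):

    class ExampleVectorTest : public ClientLibraryTestBase {};  // hypothetical

    XLA_TEST_F(ExampleVectorTest, AddTwoConstantVectors) {
      ComputationBuilder builder(client_, TestName());
      builder.Add(
          builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({1.0f, 2.0f})),
          builder.ConstantLiteral(*LiteralUtil::CreateR1<float>({0.5f, 0.5f})));
      // The ErrorSpec bounds the floating point comparison done by
      // LiteralTestUtil::ExpectNear.
      ComputeAndCompareR1<float>(&builder, {1.5f, 2.5f}, {}, ErrorSpec(0.0001f));
    }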
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
new file mode 100644
index 0000000000..690fda3ffa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -0,0 +1,409 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/bitmap.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// A client library test establishes an in-process XLA client connection.
+class ClientLibraryTestBase : public ::testing::Test {
+ protected:
+ explicit ClientLibraryTestBase(
+ perftools::gputools::Platform* platform = nullptr,
+ tensorflow::gtl::ArraySlice<string> disabled_pass_names = {});
+
+ // Returns the name of the test currently being run.
+ string TestName() const;
+
+ // TODO(b/25566808): Add helper that populates a literal from a testdata file.
+
+ // Convenience methods for building and running a computation from a builder.
+ StatusOr<std::unique_ptr<GlobalData>> Execute(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ StatusOr<std::unique_ptr<Literal>> ExecuteAndTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_output_layout = nullptr);
+
+ // Convenience OrDie variants of above methods.
+ std::unique_ptr<GlobalData> ExecuteOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ std::unique_ptr<Literal> ExecuteAndTransferOrDie(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ // Run a computation and return its value as a string. If an error
+  // occurs, return the error as a string instead.
+ string ExecuteToString(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ // Convenience methods for building and running a computation, transferring
+ // the result, and comparing it to the expected value(s). Methods are
+  // templated on the native host type, which maps to specific XLA types (see
+ // ComputationBuilder for details). For each rank, two forms are provided: one
+ // for floating point types with an ErrorSpec parameter, and one for integral
+ // types without the ErrorSpec parameter.
+ template <typename NativeT>
+ void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR0(ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ // As above, but uses a bitmap to hold the predicate vector to avoid
+ // deficiencies of vector<bool>.
+ void ComputeAndCompareR1(ComputationBuilder* builder,
+ const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ template <typename NativeT>
+ void ComputeAndCompareR2(ComputationBuilder* builder,
+ const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR2(ComputationBuilder* builder,
+ const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR3(ComputationBuilder* builder,
+ const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR3(ComputationBuilder* builder,
+ const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ template <typename NativeT>
+ void ComputeAndCompareR4(ComputationBuilder* builder,
+ const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ template <typename NativeT>
+ void ComputeAndCompareR4(ComputationBuilder* builder,
+ const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ ErrorSpec error);
+
+ // Build and run the computation and compare the result with the given
+ // literal. shape_with_layout indicates the result layout to request when
+ // calling Execute.
+ void ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout = nullptr);
+ void ComputeAndCompareLiteral(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
+
+ // ComputeAndCompare variant which returns an error status.
+ tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ const Shape* shape_with_layout = nullptr);
+ tensorflow::Status ComputeAndCompareLiteralWithStatus(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error,
+ const Shape* shape_with_layout = nullptr);
+
+  // Compare the result of the computation to a string. In XLA, strings are
+ // represented using rank-1 U8 shapes.
+ void ComputeAndCompareR1U8(
+ ComputationBuilder* builder, tensorflow::StringPiece expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+
+ // Convenience method for running a built computation, transferring the
+ // result, and comparing it to the expected tuple literal.
+ void ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ void ComputeAndCompareTuple(
+ ComputationBuilder* builder, const Literal& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
+
+ // Create scalar operations for use in reductions.
+ Computation CreateScalarRelu();
+ Computation CreateScalarMax();
+ Computation CreateScalarReluSensitivity();
+
+ // Special case convenience functions for creating filled arrays.
+
+ // Creates an array of pseudorandom values lying between the given minimum and
+ // maximum values.
+ template <typename NativeT>
+ std::vector<NativeT> CreatePseudorandomR1(const int width, NativeT min_value,
+ NativeT max_value, uint32 seed);
+ template <typename NativeT>
+ std::unique_ptr<Array2D<NativeT>> CreatePseudorandomR2(const int rows,
+ const int cols,
+ NativeT min_value,
+ NativeT max_value,
+ uint32 seed);
+
+ // Creates a (rows x cols) array filled in the following form:
+ //
+ // [ 0 1 ... cols-1]
+ // [ 1,000 1,001 ... 1000.0 + cols-1]
+ // [ ... ... ... ...]
+ // [(rows-1)*1000.0 ... ... (rows-1)*1000.0 + cols-1]
+ //
+ // If provided, offset is added uniformly to every element (e.g. an offset of
+ // 64 would cause 0 in the above to be 64, 1 to be 65, 1000 to be 1064, etc.)
+ std::unique_ptr<Array2D<float>> CreatePatternedMatrix(const int rows,
+ const int cols,
+ float offset = 0.0);
+
+ // Creates a (rows x cols) array as above, padded out to
+ // (rows_padded x cols_padded) with zeroes. Requires rows_padded >= rows
+  // and cols_padded >= cols.
+ std::unique_ptr<Array2D<float>> CreatePatternedMatrixWithZeroPadding(
+ const int rows, const int cols, const int rows_padded,
+ const int cols_padded);
+
+  // Creates a parameter instruction that wraps the given values, stores the
+  // handle for that parameter instruction into "data_handle", and returns the
+  // server-side copy of the values as a GlobalData.
+ //
+ // "parameter_number" is the parameter number.
+ // "name" is the name of the parameter instruction.
+ template <typename NativeT>
+ std::unique_ptr<GlobalData> CreateR1Parameter(
+ tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle);
+
+  // Creates a parameter instruction that wraps the given constant array
+  // "array_2d", stores the handle for that parameter instruction into
+  // "data_handle", and returns the server-side copy of the array as a
+  // GlobalData.
+ //
+ // "parameter_number" is the parameter number.
+ // "name" is the name of the parameter instruction.
+ template <typename NativeT>
+ std::unique_ptr<GlobalData> CreateR2Parameter(
+ const Array2D<NativeT>& array_2d, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle);
+
+ Client* client_;
+};
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR0(
+ ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR0(
+ ComputationBuilder* builder, NativeT expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR0<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ ComputationBuilder* builder, tensorflow::gtl::ArraySlice<NativeT> expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR2(
+ ComputationBuilder* builder, const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR2(
+ ComputationBuilder* builder, const Array2D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR2FromArray2D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR3(
+ ComputationBuilder* builder, const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR3(
+ ComputationBuilder* builder, const Array3D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR4(
+ ComputationBuilder* builder, const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
+template <typename NativeT>
+void ClientLibraryTestBase::ComputeAndCompareR4(
+ ComputationBuilder* builder, const Array4D<NativeT>& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error) {
+ static_assert(std::is_same<NativeT, float>::value ||
+ std::is_same<NativeT, double>::value,
+ "Floating point type required when specifying an ErrorSpec");
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR4FromArray4D<NativeT>(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments, error);
+}
+
+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR1Parameter(
+ tensorflow::gtl::ArraySlice<NativeT> values, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1(values);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ return data;
+}
+
+template <typename NativeT>
+std::unique_ptr<GlobalData> ClientLibraryTestBase::CreateR2Parameter(
+ const Array2D<NativeT>& array_2d, int64 parameter_number,
+ const string& name, ComputationBuilder* builder,
+ ComputationDataHandle* data_handle) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2FromArray2D(array_2d);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ *data_handle = builder->Parameter(parameter_number, literal->shape(), name);
+ return data;
+}
+
+template <typename NativeT>
+std::vector<NativeT> ClientLibraryTestBase::CreatePseudorandomR1(
+ const int width, NativeT min_value, NativeT max_value, uint32 seed) {
+ std::vector<NativeT> result(width);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
+ for (int i = 0; i < width; ++i) {
+ result[i] = generator.get();
+ }
+ return result;
+}
+
+template <typename NativeT>
+std::unique_ptr<Array2D<NativeT>> ClientLibraryTestBase::CreatePseudorandomR2(
+ const int rows, const int cols, NativeT min_value, NativeT max_value,
+ uint32 seed) {
+ auto result = MakeUnique<Array2D<NativeT>>(rows, cols);
+ test_utils::PseudorandomGenerator<NativeT> generator(min_value, max_value,
+ seed);
+ for (int y = 0; y < rows; ++y) {
+ for (int x = 0; x < cols; ++x) {
+ (*result)(y, x) = generator.get();
+ }
+ }
+ return result;
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_CLIENT_LIBRARY_TEST_BASE_H_
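Since the tests in this change mostly feed constants, here is a sketch of how the parameter helpers declared above pair with execution arguments (illustration only, not part of this change; the fixture, test name and values are hypothetical):

    XLA_TEST_F(ExampleParamTest, DoubleAParameter) {  // hypothetical fixture
      ComputationBuilder builder(client_, TestName());
      ComputationDataHandle param;
      // CreateR1Parameter transfers the values to the server, declares
      // parameter 0 on the builder through "param", and returns the
      // server-side data that must be passed when executing.
      std::unique_ptr<GlobalData> data = CreateR1Parameter<float>(
          {1.0f, 2.0f, 3.0f}, /*parameter_number=*/0, "values", &builder, &param);
      builder.Add(param, param);
      ComputeAndCompareR1<float>(&builder, {2.0f, 4.0f, 6.0f}, {data.get()},
                                 ErrorSpec(0.0001f));
    }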
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
new file mode 100644
index 0000000000..77b85af83c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -0,0 +1,127 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ClientTest : public ClientLibraryTestBase {};
+
+TEST_F(ClientTest, ExecuteWithLayout) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
+ for (const std::vector<int64>& execute_layout : layouts) {
+ for (const std::vector<int64>& transfer_layout : layouts) {
+ b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}}));
+ auto computation = b.Build();
+ ASSERT_TRUE(computation.ok()) << computation.status();
+
+ const Shape execute_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+ S32, /*dimensions=*/{2, 2}, execute_layout);
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->Execute(computation.ValueOrDie(), {},
+ &execute_shape_with_layout)
+ .ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> expected_literal =
+ test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
+ transfer_layout);
+
+ auto computed = client_->Transfer(*data, &expected_literal->shape());
+
+ LiteralTestUtil::AssertEqualShapesAndLayouts(
+ expected_literal->shape(), computed.ValueOrDie()->shape());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ }
+ }
+}
+
+TEST_F(ClientTest, ExecuteWithTupleLayout) {
+ ComputationBuilder b(client_, TestName());
+
+ b.Tuple({b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}})});
+
+ auto computation = b.Build();
+ ASSERT_TRUE(computation.ok()) << computation.status();
+
+ // Create a result shape with one element column major and the other row
+ // major.
+ Shape tuple_shape_with_layout = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{0, 1}),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{1, 0})});
+
+ auto result = client_
+ ->ExecuteAndTransfer(computation.ValueOrDie(), {},
+ &tuple_shape_with_layout)
+ .ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR2Equal<int32>({{1, 2}, {3, 4}},
+ result->tuple_literals(0));
+ LiteralTestUtil::ExpectR2Equal<int32>({{10, 20}, {30, 40}},
+ result->tuple_literals(1));
+
+ EXPECT_TRUE(ShapeUtil::IsTuple(result->shape()));
+ EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape()));
+
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetTupleElementShape(result->shape(), 0),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{0, 1})));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ ShapeUtil::GetTupleElementShape(result->shape(), 1),
+ ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
+ /*minor_to_major=*/{1, 0})));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
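
Note on the layouts exercised by ExecuteWithLayout: XLA layouts are minor-to-major dimension orderings, so for the 2x2 S32 array above {1, 0} stores rows contiguously (row-major) while {0, 1} stores columns contiguously (column-major). A minimal sketch, reusing only calls that already appear in this file (the variable names are illustrative):

    // Dimension 1 is minor-most: columns vary fastest, i.e. row-major storage.
    Shape row_major = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                                     /*minor_to_major=*/{1, 0});
    // Dimension 0 is minor-most: rows vary fastest, i.e. column-major storage.
    Shape col_major = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
                                                     /*minor_to_major=*/{0, 1});
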
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc
new file mode 100644
index 0000000000..fe4dff2109
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc
@@ -0,0 +1,90 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/codegen_test_base.h"
+
+#include <stdlib.h>
+#include <utility>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const string& pattern) {
+ std::unique_ptr<Executable> executable =
+ CompileToExecutable(std::move(hlo_module));
+ string ir_module_string = GetIrFromExecutable(*executable);
+ RunFileCheck(ir_module_string, pattern);
+}
+
+std::unique_ptr<Executable> CodegenTestBase::CompileToExecutable(
+ std::unique_ptr<HloModule> hlo_module) {
+ auto module_config = MakeUnique<HloModuleConfig>(
+ MakeProgramShape(hlo_module->entry_computation()));
+ return backend_->compiler()
+ ->Compile(std::move(hlo_module), std::move(module_config),
+ test_hlo_dumper_, backend_->default_stream_executor())
+ .ConsumeValueOrDie();
+}
+
+void CodegenTestBase::RunFileCheck(const string& input, const string& pattern) {
+  // Write the pattern to a temporary file.
+ char tempdir_template[] = "/tmp/ir_testXXXXXX";
+ char* tempdir_name = mkdtemp(tempdir_template);
+ CHECK_NOTNULL(tempdir_name);
+ string pattern_path =
+ tensorflow::io::JoinPath(tempdir_name, "xla_hlo_test_ir_pattern");
+ TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
+ pattern_path, pattern));
+
+ // Invoke FileCheck to check whether input matches `pattern`.
+ tensorflow::SubProcess file_check_process;
+ const char* test_srcdir = getenv("TEST_SRCDIR");
+ if (test_srcdir == nullptr) {
+ test_srcdir = ".";
+ }
+ string file_check_path = tensorflow::io::JoinPath(
+ test_srcdir, "external/llvm/FileCheck");
+ file_check_process.SetProgram(file_check_path,
+ {file_check_path, pattern_path});
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDIN,
+ tensorflow::ACTION_PIPE);
+ file_check_process.SetChannelAction(tensorflow::CHAN_STDERR,
+ tensorflow::ACTION_PIPE);
+ CHECK(file_check_process.Start());
+ string standard_error;
+ int exit_status = file_check_process.Communicate(
+ /*stdin_input=*/&input, /*stdout_output=*/nullptr,
+ /*stderr_output=*/&standard_error);
+
+  // FileCheck returns 0 when the inputs match. If matching fails, we output
+  // the error message generated by FileCheck.
+ SCOPED_TRACE(tensorflow::strings::StrCat("Input to FileCheck:\n", input));
+ EXPECT_EQ(0, exit_status) << standard_error;
+}
+
+} // namespace xla
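
RunFileCheck above shells out to LLVM's FileCheck binary, feeding it the IR on stdin and the pattern as a file. As a hedged sketch of a direct call from a test derived from CodegenTestBase (the IR string and the pattern are illustrative, not produced by this commit):

    // `CHECK:` lines require the given text to occur, in order, in the input.
    const string ir = "define void @foo() {\n  ret void\n}\n";
    RunFileCheck(ir, "; CHECK: define void @foo");
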
diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h
new file mode 100644
index 0000000000..50c0453107
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/codegen_test_base.h
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+
+// Tests that verify IR emitted by the CPU/GPU backend is as expected.
+class CodegenTestBase : public HloTestBase {
+ protected:
+ CodegenTestBase() {}
+
+ // Returns the embedded LLVM IR from the given executable. Codegen tests must
+ // override this method, but execution tests do not have to because they do
+ // not examine the embedded IR.
+ virtual string GetIrFromExecutable(const Executable& executable) = 0;
+
+ // Compiles the given HLO module to LLVM IR and verifies the IR matches the
+ // given pattern. `pattern` is in the FileCheck pattern matching syntax
+ // (http://llvm.org/docs/CommandGuide/FileCheck.html).
+ void CompileAndVerifyIr(std::unique_ptr<HloModule> hlo_module,
+ const string& pattern);
+
+ protected:
+ // Compiles hlo_module to an executable, CHECK-failing if this fails.
+ std::unique_ptr<Executable> CompileToExecutable(
+ std::unique_ptr<HloModule> hlo_module);
+
+ // Runs FileCheck with the given pattern over the given string and EXPECTs
+ // that FileCheck succeeded in matching the input.
+ void RunFileCheck(const string& input, const string& pattern);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_CODEGEN_TEST_BASE_H_
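
Because GetIrFromExecutable is pure virtual, a concrete codegen test supplies a backend-specific override before CompileAndVerifyIr can be used. A hypothetical subclass, shown only to illustrate the shape of that override (the class name and body are made up, not part of this commit):

    class CpuCodegenTestSketch : public CodegenTestBase {
     protected:
      string GetIrFromExecutable(const Executable& executable) override {
        // A real backend would return the LLVM IR embedded in its own
        // Executable subclass; this placeholder only demonstrates the API.
        return "; llvm ir text";
      }
    };
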
diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
new file mode 100644
index 0000000000..38ce007cb0
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc
@@ -0,0 +1,218 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class CompilationCacheTest : public ClientLibraryTestBase {
+ public:
+ void ExecuteComputationR0F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, float expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ std::unique_ptr<Literal> result =
+ client_
+ ->ExecuteAndTransfer(computation, arguments,
+ /*output_layout=*/nullptr, &execution_profile)
+ .ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ void ExecuteComputationR2F32(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments,
+ std::initializer_list<std::initializer_list<float>> expected_result,
+ bool expect_cache_hit) {
+ ExecutionProfile execution_profile;
+ auto data_handle =
+ client_
+ ->Execute(computation, arguments, /*output_layout=*/nullptr,
+ &execution_profile)
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> result =
+ client_->Transfer(*data_handle).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>(expected_result),
+ *result, error_spec_);
+ EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit());
+ }
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) {
+ std::unique_ptr<GlobalData> data_42 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(42.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_123 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(123.0f))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> data_456 =
+ client_->TransferToServer(*LiteralUtil::CreateR0<float>(456.0f))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {data_123.get()}, -123.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_456.get()}, -456.0,
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR0F32(computation, {data_42.get()}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MultipleComputations) {
+ ComputationBuilder builder_neg(client_, TestName() + "_neg");
+ builder_neg.Neg(builder_neg.ConstantR0<float>(42.0));
+ Computation computation_neg = builder_neg.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_exp(client_, TestName() + "_exp");
+ builder_exp.Exp(builder_exp.ConstantR0<float>(1.0));
+ Computation computation_exp = builder_exp.Build().ConsumeValueOrDie();
+
+ ComputationBuilder builder_add(client_, TestName() + "_add");
+ builder_add.Add(builder_add.ConstantR0<float>(2.0),
+ builder_add.ConstantR0<float>(3.0));
+ Computation computation_add = builder_add.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_exp, {}, 2.7182817,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_add, {}, 5.0,
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation_neg, {}, -42.0,
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, DifferentParameterLayouts) {
+  // Create two GlobalData arrays with the same shape but different
+  // layouts. Use these arrays as parameters to a simple computation. If the
+  // layout of the array changes, then the computation should be recompiled
+  // (cache miss).
+ auto rowmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{1, 0});
+ auto rowmaj_handle =
+ client_->TransferToServer(*rowmaj_array).ConsumeValueOrDie();
+
+ auto colmaj_array = test_utils::CreateR2LiteralWithLayout(
+ {{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1});
+ auto colmaj_handle =
+ client_->TransferToServer(*colmaj_array).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/false);
+ ExecuteComputationR2F32(computation, {rowmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+ ExecuteComputationR2F32(computation, {colmaj_handle.get()},
+ {{1.0f, 2.0f}, {3.0f, 4.0f}},
+ /*expect_cache_hit=*/true);
+}
+
+XLA_TEST_F(CompilationCacheTest, MutatedComputation) {
+  // Build a computation, execute it, then mutate it. The mutated computation
+  // should miss the cache on its first execution and hit on later ones. The
+  // mutation must be done through the stub interface because Computations
+  // built from ComputationBuilder are immutable.
+ ComputationBuilder builder(client_, TestName());
+ auto neg = builder.Neg(builder.ConstantR0<float>(42.0));
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -42.0, /*expect_cache_hit=*/true);
+
+ BinaryOpRequest request;
+ request.set_binop(BINOP_ADD);
+ *request.mutable_lhs() = neg;
+ *request.mutable_rhs() = neg;
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ ASSERT_TRUE(s.ok());
+
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/false);
+ ExecuteComputationR0F32(computation, {}, -84.0, /*expect_cache_hit=*/true);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
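
Each helper above detects caching through ExecutionProfile. A minimal sketch of that pattern, using only calls that already appear in this file (`computation` is assumed to be built as in the tests):

    ExecutionProfile profile;
    std::unique_ptr<Literal> result =
        client_
            ->ExecuteAndTransfer(computation, /*arguments=*/{},
                                 /*output_layout=*/nullptr, &profile)
            .ConsumeValueOrDie();
    // False the first time a computation is compiled, true once it is cached.
    bool cache_hit = profile.compilation_cache_hit();
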
diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc
new file mode 100644
index 0000000000..709ce5029c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ComputeConstantTest : public ClientLibraryTestBase {
+ public:
+ StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
+ ComputationDataHandle operand, ComputationBuilder* builder,
+ Layout* output_layout = nullptr) {
+ TF_ASSIGN_OR_RETURN(auto remote_computed,
+ builder->ComputeConstant(operand, output_layout));
+ TF_ASSIGN_OR_RETURN(auto computed, client_->Transfer(*remote_computed));
+ return std::move(computed);
+ }
+
+ template <class Scalar>
+ StatusOr<Scalar> ComputeConstantScalar(ComputationDataHandle operand,
+ ComputationBuilder* builder) {
+ TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(operand, builder));
+ return LiteralUtil::Get<Scalar>(*literal, {});
+ }
+
+ bool IsConstant(const ComputationDataHandle& operand,
+ ComputationBuilder* builder) {
+ StatusOr<bool> result = builder->IsConstant(operand);
+ EXPECT_TRUE(result.ok()) << result.status();
+ return result.ok() ? result.ValueOrDie() : false;
+ }
+
+  template <class Scalar>
+  void ExpectConstantComputedScalar(ComputationDataHandle operand,
+                                    Scalar expected,
+                                    ComputationBuilder* builder) {
+    StatusOr<Scalar> computed = ComputeConstantScalar<Scalar>(operand, builder);
+    ASSERT_TRUE(computed.ok()) << computed.status();
+    LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(expected),
+                                 *LiteralUtil::CreateR0(computed.ValueOrDie()));
+  }
+};
+
+TEST_F(ComputeConstantTest, ScalarInt32Literal) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.ConstantR0<int32>(42);
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<int32>(computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 42);
+}
+
+TEST_F(ComputeConstantTest, ScalarFloatAdd) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.Add(b.ConstantR0<float>(42.5f), b.ConstantR0<float>(1.5f));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 44.0f);
+}
+
+TEST_F(ComputeConstantTest, ScalarRng) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.RngUniform(b.ConstantR0<float>(1.1f), b.ConstantR0<float>(2.1f),
+ ShapeUtil::MakeShape(F32, {}));
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ ASSERT_FALSE(value.ok())
+ << "computing a RNG value should not be considered a constant";
+}
+
+TEST_F(ComputeConstantTest, DirectParam) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param");
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
+ .contains("depends on parameter"))
+ << value.status();
+}
+
+TEST_F(ComputeConstantTest, IndirectParam) {
+ ComputationBuilder b(client_, TestName());
+ auto computation =
+ b.Add(b.ConstantR0<float>(1.0f),
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param"));
+ EXPECT_FALSE(IsConstant(computation, &b));
+
+ auto value = ComputeConstantScalar<float>(computation, &b);
+ EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString())
+ .contains("depends on parameter"))
+ << value.status();
+}
+
+// Test computation of an expression interspersed with param nodes, where the
+// expression itself does not depend on the param nodes.
+TEST_F(ComputeConstantTest, UnrelatedParam) {
+ ComputationBuilder b(client_, TestName());
+
+ auto param_a = b.Parameter(10, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto constant_4 = b.Add(b.ConstantR0<float>(2.5f), b.ConstantR0<float>(1.5f));
+ auto not_constant_a = b.Add(constant_4, param_a);
+
+ auto param_b = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "param1");
+ auto constant_9 = b.Mul(b.ConstantR0<float>(2.0f), b.ConstantR0<float>(4.5f));
+ auto not_constant_b = b.Add(param_b, constant_9);
+
+ auto constant_13 = b.Add(constant_4, constant_9);
+ b.Add(not_constant_b, b.Add(constant_13, not_constant_a));
+
+ EXPECT_TRUE(IsConstant(constant_13, &b));
+
+ auto value = ComputeConstantScalar<float>(constant_13, &b);
+ ASSERT_TRUE(value.ok()) << value.status();
+ EXPECT_EQ(value.ValueOrDie(), 13.0f);
+}
+
+TEST_F(ComputeConstantTest, NonScalarAdd) {
+ ComputationBuilder b(client_, TestName());
+
+ auto computation =
+ b.Add(b.ConstantR1<int32>({1, 2}), b.ConstantR1<int32>({3, 4}));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto computed = ComputeConstantLiteral(computation, &b);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+ std::unique_ptr<Literal> expected_literal =
+ LiteralUtil::CreateR1<int32>({4, 6});
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+}
+
+TEST_F(ComputeConstantTest, IntegerDivide) {
+ ComputationBuilder b(client_, TestName());
+ auto computation = b.Div(b.ConstantR0<int32>(15), b.ConstantR0<int32>(3));
+ EXPECT_TRUE(IsConstant(computation, &b));
+
+ auto computed = ComputeConstantLiteral(computation, &b);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+ std::unique_ptr<Literal> expected_literal = LiteralUtil::CreateR0<int32>(5);
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+}
+
+XLA_TEST_F(ComputeConstantTest, Layout) {
+ ComputationBuilder b(client_, TestName());
+
+ std::vector<std::vector<int64>> layouts = {{0, 1}, {1, 0}};
+ for (const std::vector<int64>& layout : layouts) {
+ auto layout_proto = LayoutUtil::MakeLayout(layout);
+ auto computed =
+ ComputeConstantLiteral(b.Add(b.ConstantR2<int32>({{1, 2}, {3, 4}}),
+ b.ConstantR2<int32>({{10, 20}, {30, 40}})),
+ &b, &layout_proto);
+ ASSERT_TRUE(computed.ok()) << computed.status();
+
+ std::unique_ptr<Literal> expected_literal =
+ test_utils::CreateR2LiteralWithLayout<int32>({{11, 22}, {33, 44}},
+ layout);
+ LiteralTestUtil::AssertEqualShapesAndLayouts(
+ expected_literal->shape(), computed.ValueOrDie()->shape());
+ LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie());
+ }
+}
+
+// This test is permanently disabled on CPU because it requires that the
+// backend used for execution be different from the backend used for
+// ComputeConstant, which is always the CPU.
+TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
+ // Compute a trivial constant, then try to use the value in an Execute
+ // call. This should fail because the constant resides on the CPU and the
+ // Execute call is executed on a different backend.
+ ComputationBuilder constant_b(client_, TestName());
+ auto constant = constant_b.ConstantR0<int32>(42);
+ auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
+ auto literal = client_->Transfer(*handle).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectR0Equal(42, *literal);
+
+ // Build trivial computation which takes one parameter.
+ ComputationBuilder b(client_, TestName());
+ b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
+ auto computation = b.Build().ConsumeValueOrDie();
+
+ // Try to use value from ComputeConstant in Execute.
+ auto execute_status = client_->Execute(computation, {handle.get()});
+ EXPECT_FALSE(execute_status.ok());
+ EXPECT_MATCH(
+ execute_status.status().error_message(),
+ testing::ContainsRegex("argument 0 is on device Host:0 but computation "
+ "will be executed on device"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
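
The client-side constant-folding flow wrapped by ComputeConstantLiteral above can be condensed into a short sketch (the expression is illustrative; the calls mirror the fixture):

    ComputationBuilder b(client_, "compute_constant_sketch");
    auto sum = b.Add(b.ConstantR0<float>(1.0f), b.ConstantR0<float>(2.0f));
    if (b.IsConstant(sum).ValueOrDie()) {
      // ComputeConstant evaluates the constant subexpression server-side and
      // returns a data handle; Transfer reads it back as a Literal (3.0f here).
      auto handle = b.ComputeConstant(sum).ConsumeValueOrDie();
      std::unique_ptr<Literal> literal =
          client_->Transfer(*handle).ConsumeValueOrDie();
    }
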
diff --git a/tensorflow/compiler/xla/tests/concat_test.cc b/tensorflow/compiler/xla/tests/concat_test.cc
new file mode 100644
index 0000000000..9a48b19b96
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/concat_test.cc
@@ -0,0 +1,524 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+using ConcatTest = ClientLibraryTestBase;
+
+// Concatenate expects at least one argument.
+XLA_TEST_F(ConcatTest, Concat_Nothing) {
+ ComputationBuilder builder(client_, TestName());
+ auto concatenated = builder.ConcatInDim({}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(
+ computation_status.status().ToString(),
+ testing::ContainsRegex("Concatenate expects at least one argument"));
+}
+
+// Concatenate with one argument works.
+XLA_TEST_F(ConcatTest, Concat_R1_With_Nothing) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto concatenated = builder.ConcatInDim({a}, 0);
+
+ std::vector<float> expected = {42, 64};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Show that we can't concatenate R0 with R0 because we can't name the dimension
+// to concatenate on.
+XLA_TEST_F(ConcatTest, CannotConcatR0WithR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<float>(42.0);
+ auto b = builder.ConstantR0<float>(64.0);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "dimension to concatenate along out of bounds: 0"));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L0_With_R1_L1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto b = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto b = builder.ConstantR1<float>({});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {42, 64};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L2_With_R1_L1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0, 64.0});
+ auto b = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_L253_With_R1_L7) {
+ std::vector<float> lhs(253);
+ std::vector<float> rhs(7);
+ std::vector<float> expected(253 + 7);
+ for (int i = 0; i < 253; ++i) {
+ expected[i] = lhs[i] = i + 1;
+ }
+ for (int i = 0; i < 7; ++i) {
+ expected[253 + i] = rhs[i] = 253 + i + 1;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>(lhs);
+ auto b = builder.ConstantR1<float>(rhs);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_0x0_With_0x0) {
+ for (int dim : {0, 1}) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
+ auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 0));
+ auto concatenated = builder.ConcatInDim({a, b}, dim);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {},
+ ErrorSpec(0.0001));
+ }
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(1, 1);
+ auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected({
+ {0}, {64},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1_With_1x1_InDim1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(1, 1);
+ auto b_array = CreatePatternedMatrix(1, 1, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected({
+ {0, 64},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat2x0With2x5) {
+ ComputationBuilder builder(client_, TestName());
+ auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(Array2D<float>(2, 0));
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ ComputeAndCompareR2<float>(&builder, *b_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat2x3With2x5) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(2, 3);
+ auto b_array = CreatePatternedMatrix(2, 5, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected({
+ {0, 1, 2, 64, 65, 66, 67, 68},
+ {1000, 1001, 1002, 1064, 1065, 1066, 1067, 1068},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat3x2With0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(3, 2);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(Array2D<float>(0, 2));
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ ComputeAndCompareR2<float>(&builder, *a_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat3x2With5x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto a_array = CreatePatternedMatrix(3, 2);
+ auto b_array = CreatePatternedMatrix(5, 2, /*offset=*/64.0);
+ auto a = builder.ConstantR2FromArray2D(*a_array);
+ auto b = builder.ConstantR2FromArray2D(*b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected({
+ {0, 1},
+ {1000, 1001},
+ {2000, 2001},
+ {64, 65},
+ {1064, 1065},
+ {2064, 2065},
+ {3064, 3065},
+ {4064, 4065},
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x0x2_3x0x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 2));
+ auto b = builder.ConstantR3FromArray3D(Array3D<float>(3, 0, 1));
+ auto concatenated = builder.ConcatInDim({a, b}, 2);
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 3), {},
+ ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_array({
+ // 3x1x2
+ {{0, 1}},
+ {{2, 3}},
+ {{4, 5}},
+ });
+ Array3D<float> b_array({
+ // 3x1x1
+ {{6}},
+ {{7}},
+ {{8}},
+ });
+ auto a = builder.ConstantR3FromArray3D(a_array);
+ auto b = builder.ConstantR3FromArray3D(b_array);
+ auto concatenated = builder.ConcatInDim({a, b}, 2);
+
+ Array3D<float> expected({
+ {{0, 1, 6}}, {{2, 3, 7}}, {{4, 5, 8}},
+ });
+ ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R1_1x1_1x1_1x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ auto concatenated = builder.ConcatInDim({a, b, c}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_R3_3x1x2_3x1x1_3x1x1) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> a_array({
+ // 3x1x2
+ {{0, 1}},
+ {{4, 5}},
+ {{8, 9}},
+ });
+ Array3D<float> b_array({
+ // 3x1x1
+ {{2}},
+ {{6}},
+ {{10}},
+ });
+ Array3D<float> c_array({
+ // 3x1x1
+ {{3}},
+ {{7}},
+ {{11}},
+ });
+ auto a = builder.ConstantR3FromArray3D(a_array);
+ auto b = builder.ConstantR3FromArray3D(b_array);
+ auto c = builder.ConstantR3FromArray3D(c_array);
+ auto concatenated = builder.ConcatInDim({a, b, c}, 2);
+
+ Array3D<float> expected({
+ {{0, 1, 2, 3}}, {{4, 5, 6, 7}}, {{8, 9, 10, 11}},
+ });
+ ComputeAndCompareR3<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, DoubleConcatLeftAssociative) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ // concatenated = (a concat b) concat c
+ auto concatenated =
+ builder.ConcatInDim({builder.ConcatInDim({a, b}, 0), c}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, DoubleConcatRightAssociative) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0});
+ auto b = builder.ConstantR1<float>({64.0});
+ auto c = builder.ConstantR1<float>({256.0});
+ // concatenated = a concat (b concat c)
+ auto concatenated =
+ builder.ConcatInDim({a, builder.ConcatInDim({b, c}, 0)}, 0);
+
+ std::vector<float> expected = {42, 64, 256};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim0) {
+ Array2D<float> lhs(1, 1024);
+ Array2D<float> rhs(1, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ lhs(0, i) = i;
+ rhs(0, i) = i + 1024;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 0);
+
+ Array2D<float> expected(2, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ expected(0, i) = i;
+ expected(1, i) = i + 1024;
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_1x1024_With_1x1024_InDim1) {
+ Array2D<float> lhs(1, 1024);
+ Array2D<float> rhs(1, 1024);
+ for (int i = 0; i < 1024; ++i) {
+ lhs(0, i) = i;
+ rhs(0, i) = i + 1024;
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected(1, 2048);
+ for (int i = 0; i < 1024; ++i) {
+ expected(0, i) = i;
+ expected(0, i + 1024) = i + 1024;
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConcatTest, Concat_64x64_With_64x2) {
+ Array2D<float> lhs(64, 64);
+ Array2D<float> rhs(64, 2);
+ for (int i0 = 0; i0 < 64; ++i0) {
+ for (int i1 = 0; i1 < 64; ++i1) {
+ lhs(i0, i1) = (i0 << 10) | i1;
+ }
+ for (int i1 = 0; i1 < 2; ++i1) {
+ rhs(i0, i1) = (i0 << 10) | (i1 + 64);
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(lhs);
+ auto b = builder.ConstantR2FromArray2D<float>(rhs);
+ builder.ConcatInDim({a, b}, 1);
+
+ Array2D<float> expected(64, 66);
+ for (int i0 = 0; i0 < 64; ++i0) {
+ for (int i1 = 0; i1 < 66; ++i1) {
+ expected(i0, i1) = (i0 << 10) | i1;
+ }
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Show that we can't concatenate with an opaque operand.
+XLA_TEST_F(ConcatTest, CannotConcatOpaques) {
+ ComputationBuilder builder(client_, TestName());
+ auto opaque_shape = ShapeUtil::MakeOpaqueShape();
+ auto r1f32 = xla::ShapeUtil::MakeShape(xla::F32, {1});
+ auto x = builder.Parameter(0, r1f32, "x");
+ auto y = builder.Parameter(1, opaque_shape, "y");
+ auto concatenated = builder.ConcatInDim({x, y}, 0);
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_FALSE(computation_status.ok());
+ EXPECT_MATCH(
+ computation_status.status().ToString(),
+ testing::ContainsRegex(
+ "Expected non-opaque argument for operand of concatenation"));
+}
+
+XLA_TEST_F(ConcatTest, ConcatSeveralBoxedPredicates) {
+ ComputationBuilder builder(client_, TestName());
+ auto p0 = builder.ConstantR1<bool>({true});
+ auto p1 = builder.ConstantR1<bool>({false});
+ auto p2 = builder.ConstantR1<bool>({true});
+ auto concatenated = builder.ConcatInDim({p0, p1, p2}, 0);
+
+ bool expected[] = {true, false, true};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConcatTest, ConcatSeveralR1S32s) {
+ ComputationBuilder builder(client_, TestName());
+ auto a0 = builder.ConstantR1<int32>({1});
+ auto a1 = builder.ConstantR1<int32>({2, 3});
+ auto a2 = builder.ConstantR1<int32>({4, 5, 6});
+ auto a3 = builder.ConstantR1<int32>({7, 8, 9, 10});
+ auto concatenated = builder.ConcatInDim({a0, a1, a2, a3}, 0);
+
+ std::vector<int32> expected(10);
+ std::iota(expected.begin(), expected.end(), 1);
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+// Describes a binary rank-2 concatenation test.
+struct R2BinarySpec {
+ int64 lhs_dim0;
+ int64 lhs_dim1;
+ int64 rhs_dim0;
+ int64 rhs_dim1;
+ int64 concat_dimension;
+};
+
+// TEST_P harness for binary rank-2 concatenation.
+class ConcatR2BinaryTest : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<R2BinarySpec> {
+};
+
+TEST_P(ConcatR2BinaryTest, DoIt) {
+ const R2BinarySpec& spec = GetParam();
+ Array2D<int32> lhs(spec.lhs_dim0, spec.lhs_dim1);
+ lhs.FillUnique();
+ Array2D<int32> rhs(spec.rhs_dim0, spec.rhs_dim1);
+ rhs.FillUnique(1000);
+
+ ComputationBuilder builder(client_, TestName());
+ auto a0 = builder.ConstantR2FromArray2D<int32>(lhs);
+ auto a1 = builder.ConstantR2FromArray2D<int32>(rhs);
+ builder.ConcatInDim({a0, a1}, spec.concat_dimension);
+
+ std::unique_ptr<Array2D<int32>> expected =
+ ReferenceUtil::Concat2D(lhs, rhs, spec.concat_dimension);
+ ComputeAndCompareR2<int32>(&builder, *expected, {});
+}
+
+// Regression test for b/31944287. x*y is used (at the same index) by all
+// operands of the concat. We should emit x*y in each of the concat's three
+// incoming basic blocks, because those basic blocks are not control-equivalent.
+//
+// x*y
+// / | \
+// add1 add2 add3
+// \ | /
+// concat
+XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) {
+ auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {});
+ auto x_literal = LiteralUtil::CreateR0<float>(2.f);
+ auto y_literal = LiteralUtil::CreateR0<float>(3.f);
+ auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
+ auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, f32_scalar, "x");
+ auto y = builder.Parameter(1, f32_scalar, "y");
+ auto mul = builder.Mul(x, y);
+ auto add1 = builder.Add(mul, builder.ConstantR1<float>({1.f, 2.f}));
+ auto add2 = builder.Add(mul, builder.ConstantR1<float>({3.f, 4.f}));
+ auto add3 = builder.Add(mul, builder.ConstantR1<float>({5.f, 6.f}));
+ builder.ConcatInDim({add1, add2, add3}, /*dimension=*/0);
+
+ ComputeAndCompareR1<float>(&builder, {7., 8., 9., 10., 11., 12.},
+ {x_data.get(), y_data.get()}, ErrorSpec(1e-4));
+}
+
+INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest,
+ ::testing::Values(R2BinarySpec{1, 1, 1, 1, 0},
+ R2BinarySpec{1, 1, 1, 1, 1},
+ R2BinarySpec{4, 3, 4, 3, 0},
+ R2BinarySpec{4, 3, 4, 3, 1},
+ R2BinarySpec{7, 128, 1, 128, 0},
+ R2BinarySpec{8, 127, 8, 1, 1}));
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
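
A condensed sketch of the ConcatInDim contract these tests cover: operands must agree on every dimension except the concatenation dimension, and the result's extent along that dimension is the sum of the operands' extents (shapes are illustrative, mirroring Concat2x3With2x5):

    ComputationBuilder builder(client_, "concat_sketch");
    auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 3));  // 2x3
    auto b = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 5));  // 2x5
    builder.ConcatInDim({a, b}, /*dimension=*/1);  // result shape: 2x8
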
diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc
new file mode 100644
index 0000000000..58d52ac116
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/constants_test.cc
@@ -0,0 +1,193 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that constants in program memory round trip as expected.
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConstantsTest : public ClientLibraryTestBase {
+ protected:
+ const ErrorSpec error_spec_{1e-3, 1e-5};
+};
+
+TEST_F(ConstantsTest, ZeroCellF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, OneCellF32) {
+ std::vector<float> constant = {2.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, OneCellS32) {
+ std::vector<int32> constant = {2};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>(constant);
+
+ ComputeAndCompareR1<int32>(&builder, constant, {});
+}
+
+TEST_F(ConstantsTest, OneCellU32) {
+ std::vector<uint32> constant = {2};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<uint32>(constant);
+
+ ComputeAndCompareR1<uint32>(&builder, constant, {});
+}
+
+TEST_F(ConstantsTest, EightCells) {
+ std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, SixteenCells) {
+ std::vector<float> constant = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0,
+ 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0};
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>(constant);
+
+ ComputeAndCompareR1<float>(&builder, constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Empty_0x2) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 2), {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Small_2x2) {
+ std::unique_ptr<Array2D<float>> constant =
+ MakeLinspaceArray2D(100.0, 200.0, 2, 2);
+
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2FromArray2D<float>(*constant);
+
+ ComputeAndCompareR2<float>(&builder, *constant, {}, error_spec_);
+}
+
+TEST_F(ConstantsTest, Empty_3x0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto constant = builder.ConstantLiteral(
+ *LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(3, 0, 2)));
+
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 2), {});
+}
+
+TEST_F(ConstantsTest, Small_2x2x2) {
+ ComputationBuilder builder(client_, TestName());
+ Array3D<float> array3d({
+ // x0 x1
+ {{1.f, 2.f}, // y0
+ {3.f, 4.f}}, // y1
+
+ {{5.f, 6.f}, // y0
+ {7.f, 8.f}}, // y1
+ });
+ auto constant = builder.ConstantLiteral(
+ *LiteralUtil::CreateR3FromArray3D<float>(array3d));
+
+ ComputeAndCompareR3<float>(&builder, array3d, {});
+}
+
+TEST_F(ConstantsTest, Small_3x2x1x1) {
+ Array4D<float> input_array(3, 2, 1, 1);
+ Array2D<float> pz({
+ // z0 z1
+ {-1.0f, 4.1f}, // p0
+ {2.0f, 4.1f}, // p1
+ {5.0f, 4.4f}, // p2
+ });
+ input_array.FillWithPZ(pz);
+ Literal input_literal = *LiteralUtil::CreateR4FromArray4D(input_array);
+
+ {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantLiteral(input_literal);
+ ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
+ }
+
+ {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR4FromArray4D<float>(input_array);
+ ComputeAndCompareR4<float>(&builder, input_array, {}, error_spec_);
+ }
+}
+
+// TODO(b/29263943): Support tuple constants.
+TEST_F(ConstantsTest, DISABLED_TupleConstant) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantLiteral(*LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>({{1.0}, {2.0}}).get(),
+ LiteralUtil::CreateR1<float>({2.0, 42}).get()}));
+
+ std::unique_ptr<Literal> result = ExecuteAndTransferOrDie(&builder, {});
+
+ LiteralTestUtil::ExpectR2Near<float>({{1.0}, {2.0}},
+ result->tuple_literals(0), error_spec_);
+ LiteralTestUtil::ExpectR1Near<float>({2.0, 42.0}, result->tuple_literals(1),
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
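
Small_3x2x1x1 above builds the same constant two ways, through ConstantLiteral and through ConstantR4FromArray4D. A compressed sketch of that equivalence (array contents are illustrative):

    Array4D<float> input_array(3, 2, 1, 1);
    input_array.FillWithPZ(
        Array2D<float>({{-1.0f, 4.1f}, {2.0f, 4.1f}, {5.0f, 4.4f}}));
    ComputationBuilder builder(client_, "constants_sketch");
    // Two equivalent ways to embed the same R4 constant:
    builder.ConstantLiteral(*LiteralUtil::CreateR4FromArray4D(input_array));
    builder.ConstantR4FromArray4D<float>(input_array);
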
diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc
new file mode 100644
index 0000000000..9f8c3a9aeb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convert_test.cc
@@ -0,0 +1,210 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdint>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvertTest : public ClientLibraryTestBase {
+ public:
+ explicit ConvertTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+};
+
+TEST_F(ConvertTest, ConvertR1S32ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.0f, 64.0f});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertR1S32ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1S0S32ToR1S0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>({});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertR1F32ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({42.6, 64.4});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int64>({32, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0, 64.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0, 64.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1S32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, S32);
+
+ std::vector<int32_t> expected = {32, 64};
+ ComputeAndCompareR1<int32_t>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1U8ToR1U32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<uint8_t>({32, 64});
+ builder.ConvertElementType(a, U32);
+
+ std::vector<uint32_t> expected = {32, 64};
+ ComputeAndCompareR1<uint32_t>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F64) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({32.0f, 64.0f});
+ builder.ConvertElementType(a, F64);
+
+ std::vector<double> expected = {32.0, 64.0};
+ ComputeAndCompareR1<double>(&builder, expected, {});
+}
+
+XLA_TEST_F(ConvertTest, ConvertR1F64ToR1F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<double>({32.0, 64.0});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {32.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertS32Extremes) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<int32>(
+ {std::numeric_limits<int32>::min(), std::numeric_limits<int32>::max()});
+ builder.ConvertElementType(a, F32);
+
+ std::vector<float> expected = {
+ static_cast<float>(std::numeric_limits<int32>::min()),
+ static_cast<float>(std::numeric_limits<int32>::max())};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ConvertTest, ConvertMapToS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto b = builder.CreateSubBuilder("convert");
+ auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in");
+ b->ConvertElementType(param, S32);
+ auto a = builder.ConstantR1<float>({42.0f, 64.0f});
+ builder.Map({a}, b->BuildAndNoteError());
+
+ std::vector<int32> expected = {42, 64};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(ConvertTest, ConvertMapToF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto b = builder.CreateSubBuilder("convert");
+ auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in");
+ b->ConvertElementType(param, F32);
+ auto a = builder.ConstantR1<int32>({42, 64});
+ builder.Map({a}, b->BuildAndNoteError());
+
+ std::vector<float> expected = {42.0f, 64.0f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Regression test for b/31758660. When ReshapeMover transforms
+//   input -> reshape -> convert
+// to
+//   input -> convert -> reshape
+// the new convert should have the same element type as the old convert. (A
+// standalone sketch of this property follows the test below.)
+TEST_F(ConvertTest, ConvertReshape) {
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR1<int32>({42});
+ auto reshape = builder.Reshape(input, /*dimensions=*/{0}, /*new_sizes=*/{});
+ builder.ConvertElementType(reshape, F32);
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, ErrorSpec(0.0001));
+}
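
The ConvertReshape test above exercises the ReshapeMover property described in its comment. Below is a minimal standalone C++ sketch of that property, not part of the patch and not using the XLA API: reordering a reshape and a convert leaves the value unchanged, provided the convert keeps its original element type.

    #include <cstdint>
    #include <iostream>

    int main() {
      const int32_t input[1] = {42};

      // input -> reshape (R1 to R0) -> convert (S32 to F32)
      const int32_t reshaped = input[0];
      const float reshape_then_convert = static_cast<float>(reshaped);

      // input -> convert (S32 to F32) -> reshape (R1 to R0)
      const float converted[1] = {static_cast<float>(input[0])};
      const float convert_then_reshape = converted[0];

      // Both orders produce 42.0f; the swap is only valid if the convert's
      // element type (F32 here) is preserved.
      std::cout << reshape_then_convert << " " << convert_then_reshape << "\n";
      return 0;
    }
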
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
new file mode 100644
index 0000000000..9f38dc4b36
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -0,0 +1,117 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <array>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {};
+
+// Tests the convolution operation with invalid input dimension numbers.
+TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
+ auto dimension_numbers_status =
+ ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3);
+ ASSERT_FALSE(dimension_numbers_status.ok());
+ ASSERT_MATCH(dimension_numbers_status.status().error_message(),
+ testing::ContainsRegex("input are not unique"));
+}
+
+// Tests the convolution operation with invalid weight dimension numbers.
+TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) {
+ auto dimension_numbers_status =
+ ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3);
+ ASSERT_FALSE(dimension_numbers_status.ok());
+ ASSERT_MATCH(dimension_numbers_status.status().error_message(),
+ testing::ContainsRegex("weight are not unique"));
+}
+
+XLA_TEST_F(ConvolutionDimensionNumbersTest,
+ TwoConvsWithDifferentDimensionNumbers) {
+ auto input_array = MakeUnique<Array4D<float>>(2, 3, 5, 5);
+ input_array->FillWithMultiples(0.1);
+ auto weight_array = MakeUnique<Array4D<float>>(4, 3, 1, 1);
+ weight_array->FillWithMultiples(0.2);
+ auto weight_data =
+ client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(*input_array);
+ auto weight =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight");
+ auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid);
+
+ ConvolutionDimensionNumbers dim_nums =
+ ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ // Swap batch_dimension and feature_dimension.
+ int64 tmp = dim_nums.batch_dimension();
+ dim_nums.set_batch_dimension(dim_nums.feature_dimension());
+ dim_nums.set_feature_dimension(tmp);
+ // Swap kernel_input_feature_dimension and kernel_output_feature_dimension.
+ tmp = dim_nums.kernel_input_feature_dimension();
+ dim_nums.set_kernel_input_feature_dimension(
+ dim_nums.kernel_output_feature_dimension());
+ dim_nums.set_kernel_output_feature_dimension(tmp);
+ builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid,
+ dim_nums);
+
+ auto expected_conv1 = ReferenceUtil::ConvArray4D(*input_array, *weight_array,
+ {1, 1}, Padding::kValid);
+ auto expected_conv2 = ReferenceUtil::ConvArray4DGeneralDimensions(
+ *input_array, *expected_conv1, {1, 1}, Padding::kValid, dim_nums);
+
+ ComputeAndCompareR4<float>(&builder, *expected_conv2, {weight_data.get()},
+ ErrorSpec(0.001, 0.01));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
new file mode 100644
index 0000000000..ffbda89b94
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -0,0 +1,361 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests of convolution with trivial kernels and no special variations (like
+// strides and padding).
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionTest : public ClientLibraryTestBase {
+ protected:
+#if XLA_TEST_BACKEND_GPU
+  // XLA:GPU sometimes uses FFT convolution, which isn't as precise as spatial
+  // convolution, so we relax the absolute error threshold. (See the tolerance
+  // sketch after this class definition.)
+ ErrorSpec error_spec_ = ErrorSpec(1e-3);
+#else
+ ErrorSpec error_spec_ = ErrorSpec(1e-4);
+#endif
+};
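
The ErrorSpec above takes an absolute tolerance; the two-argument form used elsewhere in this patch (e.g. ErrorSpec(0.001, 0.01)) appears to add a relative tolerance as well. Below is a minimal standalone sketch of a combined absolute-or-relative check under that assumption; the WithinErrorSpec helper and the OR-combination rule are illustrative only, not XLA's literal-comparison code.

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Illustrative helper: accept `actual` if it is within an absolute
    // tolerance OR within a relative tolerance of `expected`.
    bool WithinErrorSpec(float expected, float actual, float abs_tol,
                         float rel_tol) {
      const float abs_err = std::fabs(actual - expected);
      const float rel_err = abs_err / std::max(std::fabs(expected), 1e-30f);
      return abs_err <= abs_tol || rel_err <= rel_tol;
    }

    int main() {
      // Large outputs: a 0.05 absolute error is a tiny relative error.
      std::printf("%d\n", WithinErrorSpec(30496.0f, 30496.05f, 1e-4f, 1e-2f));  // 1
      // Small outputs: the same 0.05 error fails both tolerances.
      std::printf("%d\n", WithinErrorSpec(0.1f, 0.15f, 1e-4f, 1e-2f));          // 0
      return 0;
    }
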
+
+XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
+ const int kInputActivationSizeY = 3;
+ const int kInputActivationSizeX = 3;
+ const int kInputActivationSizeZ = 256;
+ const int kKernelSizeX = 2;
+ const int kKernelSizeY = 2;
+ const int kOutputActivationSizeZ = 256;
+ const int kMiniBatchSize = 4;
+ auto alhs =
+ MakeUnique<Array4D<float>>(kMiniBatchSize, kInputActivationSizeZ,
+ kInputActivationSizeY, kInputActivationSizeX);
+ alhs->FillWithMultiples(1.0f);
+ ASSERT_EQ(3, alhs->width());
+ ASSERT_EQ(3, alhs->height());
+
+ auto arhs =
+ MakeUnique<Array4D<float>>(kOutputActivationSizeZ, kInputActivationSizeZ,
+ kKernelSizeY, kKernelSizeX);
+ Array2D<float> rhs_raster({
+ {1.0f, 0.0f}, // row 0
+ {0.0f, 0.0f}, // row 1
+ });
+ arhs->FillWithYX(rhs_raster);
+ ASSERT_EQ(2, arhs->width());
+ ASSERT_EQ(2, arhs->height());
+
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
+ auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
+
+ Array4D<float> input(1, 1, 1, 2);
+ input.FillWithYX(Array2D<float>({
+ {1, 2},
+ }));
+ Array4D<float> filter(1, 1, 1, 2);
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ }));
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests valid padding for 2D convolution in raster space.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 2, 2);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests same padding for 2D convolution in raster space.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 2, 2);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ {5, 6},
+ {7, 8},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// Tests same padding for 2D convolution in raster space with an odd-sized
+// kernel.
+TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
+
+ Array4D<float> input(1, 1, 4, 4);
+ // clang-format off
+ input.FillWithYX(Array2D<float>({
+ {1, 2, 3, 4 },
+ {5, 6, 7, 8 },
+ {9, 10, 11, 12},
+ {13, 14, 15, 16},
+ }));
+ // clang-format on
+ Array4D<float> filter(1, 1, 3, 3);
+ // clang-format off
+ filter.FillWithYX(Array2D<float>({
+ { 5, 6, 7},
+ { 8, 9, 10},
+ {11, 12, 13},
+ }));
+ // clang-format on
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// TODO(b/32873825): implement 1D convolution on GPU.
+XLA_TEST_F(ConvolutionTest, DISABLED_ON_GPU(Convolve1D_1x2x5_1x2x2_Valid)) {
+ ComputationBuilder builder(client_, TestName());
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1}, Padding::kValid);
+ }
+
+ Array3D<float> input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}});
+ Array3D<float> filter({{{10, 20}, {30, 40}}});
+
+ Array3D<float> expected({{{510, 610, 710, 810}}});
+
+ auto input_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR3<float>(&builder, expected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+// TODO(b/32873825): implement 3D convolution on GPU.
+XLA_TEST_F(ConvolutionTest,
+ DISABLED_ON_GPU(Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid)) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<int64> input_dims = {1, 4, 2, 3, 3};
+ std::vector<int64> filter_dims = {2, 2, 2, 3, 3};
+ Shape input_shape = ShapeUtil::MakeShape(F32, input_dims);
+ Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims);
+ {
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+
+ // Tensorflow dimension numbers for 3D convolution.
+ ConvolutionDimensionNumbers dnums;
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.add_spatial_dimensions(3);
+ dnums.set_feature_dimension(4);
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.add_kernel_spatial_dimensions(2);
+ dnums.set_kernel_input_feature_dimension(3);
+ dnums.set_kernel_output_feature_dimension(4);
+
+ builder.ConvWithGeneralDimensions(input, filter, {1, 1, 1}, Padding::kValid,
+ dnums);
+ }
+
+ std::vector<float> input_elems(ShapeUtil::ElementsIn(input_shape));
+ std::iota(input_elems.begin(), input_elems.end(), 1.0f);
+ auto input_r1 = LiteralUtil::CreateR1<float>(input_elems);
+ auto input_r5 =
+ LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie();
+
+ std::vector<float> filter_elems(ShapeUtil::ElementsIn(filter_shape));
+ std::iota(filter_elems.begin(), filter_elems.end(), 1.0f);
+ auto filter_r1 = LiteralUtil::CreateR1<float>(filter_elems);
+ auto filter_r5 =
+ LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie();
+
+ auto expected_r1 = LiteralUtil::CreateR1<float>(
+ {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446,
+ 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470});
+ auto expected_r5 =
+ LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie();
+
+ auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*filter_r5).ConsumeValueOrDie();
+
+ ComputeAndCompareLiteral(&builder, *expected_r5,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
new file mode 100644
index 0000000000..b599f9b95b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -0,0 +1,1294 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests of convolution variants -- kernel sizes, padding, and strides --
+// on small-sized data.
+
+#include <algorithm>
+#include <initializer_list>
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ConvolutionVariantsTest : public ClientLibraryTestBase {
+ protected:
+#if XLA_TEST_BACKEND_GPU
+  // XLA:GPU sometimes uses FFT convolution, which isn't as precise as spatial
+  // convolution, so we relax the absolute error threshold.
+ ErrorSpec error_spec_ = ErrorSpec(1e-1, 1e-5);
+#else
+ ErrorSpec error_spec_ = ErrorSpec(1e-4, 1e-2);
+#endif
+};
+
+TEST_F(ConvolutionVariantsTest, Minimal) {
+ ComputationBuilder builder(client_, TestName());
+
+ const Array4D<float> input_array(1, 1, 1, 1, {2});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {3});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ const Array4D<float> expected(1, 1, 1, 1, {6});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
+ ComputationBuilder builder(client_, TestName());
+
+ const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {2});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ const Array4D<float> expected(5, 1, 1, 1, {2, 4, 6, 8, 10});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Flat1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(2, 1, 3, 4);
+ input_array.FillWithMultiples(1);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {2.3});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(2, 1, 3, 4);
+ expected.FillWithMultiples(2.3);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Deep1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 2, 1, 1, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 3, 1, 1, {12, 34, 56});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 2, {1, 2});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 1, {12});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {12, 23});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 1, {12, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 1, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {13, 24});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 2, {1000, 100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 1, {1234});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(
+ 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0
+ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 0, 0}); // plane 1
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(
+ 2, 2, 1, 2, {1000, 100, 10, 1, 0.1, 0.01, 0.001, 0.0001});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(
+ 2, 2, 2, 2,
+ {167, 1278, 3490, 4500, 0.0167, 0.1278, 0.3490, 0.4500, // plane 0
+ 334, 2556, 6980, 9000, 0.0334, 0.2556, 0.6980, 0.9000}); // plane 1
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {10, 30});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 3, {10, 30, 50});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 1, {123});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {123, 345});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 1, {10});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {2, 2}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 2, {10, 30, 70, 90});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 1, {1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 3, {10, 20, 30});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+ Array4D<float> expected(1, 1, 1, 1, {20});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 5, {10000, 1000, 100, 10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+ Array4D<float> expected(1, 1, 1, 3, {123, 1230, 12300});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 3, 3, {10000, 0, 1000, // row 0
+ 0, 100, 0, // row 1
+ 10, 0, 1}); // row 2
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+ Array4D<float> expected(1, 1, 2, 2, {104, 230, 2300, 10400});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 2, 1, 1, {10, 1});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+
+ Array4D<float> expected(1, 1, 1, 2, {13, 24});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 2, 2, {7, 13, 17, 23});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 2, 2, {216, 276, 396, 456});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ const Array4D<float> filter_array(1, 1, 1, 2, {7, 13});
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 1, 1, 2, {33, 53});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(64);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 1, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(128);
+ std::fill(filter_data.begin(), filter_data.begin() + 64, 1.0);
+ std::fill(filter_data.begin() + 64, filter_data.begin() + 128, 2.0);
+ const Array4D<float> filter_array(2, 1, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 2, 1, 1, {2016, 4032});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(16 * 1 * 1 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(16, 1, 1, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {1, 2, 3, 4, 5, 6, 7, 8,
+ 9, 10, 11, 12, 13, 14, 15, 16};
+ Array4D<float> expected(16, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ constexpr int bs = 16;
+ constexpr int kx = 2;
+ constexpr int ky = 2;
+ Array4D<float> input_array(bs, 1, ky, kx);
+ for (int i0 = 0; i0 < bs; ++i0) {
+ for (int i2 = 0; i2 < ky; ++i2) {
+ for (int i3 = 0; i3 < kx; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * ky * kx);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data(bs);
+ for (int i = 0; i < bs; ++i) {
+ expected_data[i] = 10 * (i + 1);
+ }
+ Array4D<float> expected(bs, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
+ ComputationBuilder builder(client_, TestName());
+
+ constexpr int kx = 2;
+ constexpr int ky = 2;
+ constexpr int bs = 3;
+ Array4D<float> input_array(bs, 1, ky, kx);
+ for (int i0 = 0; i0 < bs; ++i0) {
+ for (int i2 = 0; i2 < ky; ++i2) {
+ for (int i3 = 0; i3 < kx; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * ky * kx);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, ky, kx, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {
+ 23, 33, 43,
+ };
+ Array4D<float> expected(bs, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(16, 1, 8, 8);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i2 = 0; i2 < 8; ++i2) {
+ for (int i3 = 0; i3 < 8; ++i3) {
+ input_array(i0, 0, i2, i3) = i0 + i2 + i3 + 1;
+ }
+ }
+ }
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 8 * 8);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ const Array4D<float> filter_array(1, 1, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {
+ 19664, 21744, 23824, 25904, 27984, 30064, 32144, 34224,
+ 36304, 38384, 40464, 42544, 44624, 46704, 48784, 50864,
+ };
+ Array4D<float> expected(16, 1, 1, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(1, 2, 1, 1, {14240, 30496});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(2 * 2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(2, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(2, 2, 1, 1, {14240, 30496, 38816, 87840});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(32 * 2 * 8 * 8);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(32, 2, 8, 8, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(2 * 2 * 8 * 8);
+ std::fill(filter_data.begin(), filter_data.begin() + filter_data.size() / 4,
+ 1.0);
+ std::fill(filter_data.begin() + filter_data.size() / 4,
+ filter_data.begin() + filter_data.size() / 2, 2.0);
+ std::fill(filter_data.begin() + filter_data.size() / 2,
+ filter_data.begin() + 3 * filter_data.size() / 4, 3.0);
+ std::fill(filter_data.begin() + 3 * filter_data.size() / 4, filter_data.end(),
+ 4.0);
+ const Array4D<float> filter_array(2, 2, 8, 8, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::vector<float> expected_data = {
+ 14240, 30496, 38816, 87840, 63392, 145184, 87968,
+ 202528, 112544, 259872, 137120, 317216, 161696, 374560,
+ 186272, 431904, 210848, 489248, 235424, 546592, 260000,
+ 603936, 284576, 661280, 309152, 718624, 333728, 775968,
+ 358304, 833312, 382880, 890656, 407456, 948000, 432032,
+ 1005344, 456608, 1062688, 481184, 1120032, 505760, 1177376,
+ 530336, 1.23472e+06, 554912, 1292064, 579488, 1349408, 604064,
+ 1406752, 628640, 1464096, 653216, 1.52144e+06, 677792, 1578784,
+ 702368, 1636128, 726944, 1693472, 751520, 1750816, 776096,
+ 1.80816e+06,
+ };
+ Array4D<float> expected(32, 2, 1, 1, expected_data);
+  // The output elements can be larger than 1e+5, so the absolute error can
+  // sometimes be large; we therefore focus on the relative error for this
+  // test case.
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
+ ComputationBuilder builder(client_, TestName());
+
+ Array4D<float> input_array(16, 16, 1, 1);
+ Array4D<float> filter_array(16, 16, 1, 1);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i1 = 0; i1 < 16; ++i1) {
+ input_array(i0, i1, 0, 0) = 1000 * i0 + i1;
+ filter_array(i0, i1, 0, 0) = 1;
+ }
+ }
+
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ Array4D<float> expected(16, 16, 1, 1);
+ for (int i0 = 0; i0 < 16; ++i0) {
+ for (int i1 = 0; i1 < 16; ++i1) {
+ expected(i0, i1, 0, 0) = 16000 * i0 + 120;
+ }
+ }
+
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 4 * 6);
+ std::iota(input_data.begin(), input_data.end(), 0.0);
+ Array4D<float> input_array(1, 1, 4, 6, input_data);
+
+ Array4D<float> filter_array(1, 1, 2, 3, {1, 10, 100, 2, 20, 200});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 2, 2, {3924, 4257, 5922, 6255});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 3 * 4);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 3, 4, input_data);
+
+ Array4D<float> filter_array(1, 1, 4, 3, {100, 10, 1, //
+ 200, 20, 2, //
+ 300, 30, 3, //
+ 400, 40, 4});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1},
+ /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2},
+ /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 3, 5, {204, 40, 406, 60, 608, //
+ 1518, 180, 1821, 210, 2124, //
+ 4146, 460, 4651, 510, 5156});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-1, -1}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 1, 2, {23, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-1, 2}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 1, 5, {23, 34, 45, 50, 0});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneral(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {2, -1}},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ Array4D<float> expected(1, 1, 1, 5, {0, 1, 12, 23, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {3, 2}},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ // input:
+ // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
+ // ---pad---> [0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 0]
+ // filter:
+ // [10, 1] --dilate-> [10, 0, 1]
+ Array4D<float> expected(1, 1, 1, 12,
+ {0, 1, 0, 12, 0, 23, 0, 34, 0, 45, 0, 50});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 1 * 1 * 5);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 1, 1, 5, input_data);
+
+ Array4D<float> filter_array(1, 1, 1, 2, {10, 1});
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.ConvGeneralDilated(
+ /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
+ /*padding=*/{{0, 0}, {-3, -2}},
+ /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+
+ // input:
+ // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
+ // ---pad---> [0, 3, 0, 4]
+ // filter:
+ // [10, 1] --dilate-> [10, 0, 1]
+ Array4D<float> expected(1, 1, 1, 2, {0, 34});
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
+ constexpr int bs = 1;
+ constexpr int iz = 1;
+ constexpr int oz = 2;
+ constexpr int iy = 2;
+ constexpr int ix = 3;
+ constexpr int ky = 1;
+ constexpr int kx = 2;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
+ constexpr int bs = 1;
+ constexpr int iz = 16;
+ constexpr int oz = 1;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 1;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 16;
+ constexpr int iy = 1;
+ constexpr int ix = 1;
+ constexpr int ky = 1;
+ constexpr int kx = 1;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x16x16_Filter16x16x16x16) {
+ constexpr int bs = 16;
+ constexpr int iz = 16;
+ constexpr int oz = 16;
+ constexpr int iy = 16;
+ constexpr int ix = 16;
+ constexpr int ky = 16;
+ constexpr int kx = 16;
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<float> input_data(bs * iz * iy * ix);
+ for (float& f : input_data) {
+ f = distribution(rng);
+ }
+ std::vector<float> kernel_data(oz * iz * ky * kx);
+ for (float& f : kernel_data) {
+ f = distribution(rng);
+ }
+
+ Array4D<float> input_array(bs, iz, iy, ix, input_data);
+ Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+
+ std::unique_ptr<Array4D<float>> expected = ReferenceUtil::ConvArray4D(
+ input_array, filter_array, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 2 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 1.0);
+ Array4D<float> filter_array(1, 2, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+ // Tensorflow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+  // Tests padding sizes that don't correspond to either SAME or VALID padding.
+  // (See the output-shape sketch after this test.)
+ builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+
+ std::vector<float> expected_data = {
+ 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, //
+ 0, 2, 5, 8, 3, 0, 0, //
+ 0, 8, 14, 17, 6, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0 //
+ };
+ Array4D<float> expected(1, 5, 7, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+ // Tensorflow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ // Tests padding sizes that don't correspond to either SAME or VALID padding.
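+ // With a 1x1 filter the output shape is 1x5x8x1 (height 2+2+1 and width
+ // 3+2+3), and every input value is simply scaled by the filter value 2.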
+ builder.ConvGeneral(input, filter, {1, 1}, {{2, 1}, {2, 3}}, dnums);
+
+ std::vector<float> expected_data = {
+ 0, 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, 0, //
+ 0, 0, 2, 4, 6, 0, 0, 0, //
+ 0, 0, 8, 10, 12, 0, 0, 0, //
+ 0, 0, 0, 0, 0, 0, 0, 0 //
+ };
+ Array4D<float> expected(1, 5, 8, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 1);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 1, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 1 * 1);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 1, 1, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+ // Tensorflow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ // Tests padding sizes of zero. With no padding and a 1x1 filter, the
+ // convolution can be computed as a matmul.
+ builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+
+ std::vector<float> expected_data = {
+ 2, 4, 6, //
+ 8, 10, 12,
+ };
+ Array4D<float> expected(1, 2, 3, 1, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::vector<float> input_data(1 * 2 * 3 * 2);
+ std::iota(input_data.begin(), input_data.end(), 1.0);
+ Array4D<float> input_array(1, 2, 3, 2, input_data);
+ auto input = builder.ConstantR4FromArray4D<float>(input_array);
+
+ std::vector<float> filter_data(1 * 1 * 2 * 3);
+ std::iota(filter_data.begin(), filter_data.end(), 2.0);
+ Array4D<float> filter_array(1, 1, 2, 3, filter_data);
+ auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
+
+ ConvolutionDimensionNumbers dnums;
+ // NHWC input format.
+ dnums.set_batch_dimension(0);
+ dnums.add_spatial_dimensions(1);
+ dnums.add_spatial_dimensions(2);
+ dnums.set_feature_dimension(3);
+
+ // Tensorflow filter shape: [ H, W, inC, outC ]
+ dnums.add_kernel_spatial_dimensions(0);
+ dnums.add_kernel_spatial_dimensions(1);
+ dnums.set_kernel_input_feature_dimension(2);
+ dnums.set_kernel_output_feature_dimension(3);
+
+ // Tests padding sizes of zero. With no padding and a 1x1 filter, the
+ // convolution can be computed as a matmul.
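+ // With a 1x1 filter and NHWC layout this is equivalent to multiplying the
+ // [1*2*3, 2] matrix of input pixels by the [2, 3] filter matrix.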
+ builder.ConvGeneral(input, filter, {1, 1}, {{0, 0}, {0, 0}}, dnums);
+
+ std::vector<float> expected_data = {
+ 12, 15, 18, //
+ 26, 33, 40, //
+ 40, 51, 62, //
+ 54, 69, 84, //
+ 68, 87, 106, //
+ 82, 105, 128, //
+ };
+ Array4D<float> expected(1, 2, 3, 3, expected_data);
+ ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
+}
+
+// Regression test for b/32034796.
+//
+// XLA:GPU fuses
+// Conv([1,2,3], Reverse([5,6]), padding_low=1)
+// into
+// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1)
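+//
+// With padding_low=1 the gradients become 0,1,2,3; correlating them with the
+// reversed filter 6,5 gives 0*6+1*5=5, 1*6+2*5=16, 2*6+3*5=27, matching the
+// expected result below.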
+TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 2, /*values=*/{5, 6}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 0}});
+ ComputeAndCompareR4<float>(&builder, {{{{5, 16, 27}}}}, {}, error_spec_);
+}
+
+// XLA:GPU fuses
+// Conv([1], Reverse([1,10,100]), padding_high=3, base_dilation=3)
+// into
+// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1))
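+//
+// The single gradient value 1, padded with padding_high=3, becomes 1,0,0,0.
+// The two windows against the reversed filter 100,10,1 give 1*100=100 and 0,
+// matching the expected result below.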
+TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvGeneralDilated(
+ gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 3}},
+ /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ ComputeAndCompareR4<float>(&builder, {{{{100, 0}}}}, {}, error_spec_);
+}
+
+// XLA:GPU fuses
+// Conv([1], Reverse([1,10,100]), padding=(1,1))
+// into
+// BackwardInputConv([1], [1,10,100], padding=(1,1))
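+//
+// The padded gradients are 0,1,0; the single window against the reversed
+// filter 100,10,1 gives 0*100 + 1*10 + 0*1 = 10, matching the expected
+// result below.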
+TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 1}});
+ ComputeAndCompareR4<float>(&builder, {{{{10}}}}, {}, error_spec_);
+}
+
+// HLO pattern
+// Conv([1,2,3], Reverse([1,10]), padding_high=2)
+// could be fused to
+// BackwardInputConv([1,2,3], [1,10], padding_low=1, padding_high=-1)
+//
+// However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't
+// support negative padding on backward convolution yet (b/32744257).
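+//
+// With padding (0,2) the gradients become 1,2,3,0,0; correlating them with
+// the reversed filter 10,1 gives 1*10+2=12, 2*10+3=23, 3*10+0=30, and 0,
+// matching the expected result below.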
+TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
+ auto weights = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 2, /*values=*/{1, 10}));
+ auto mirrored_weights = builder.Rev(weights, {2, 3});
+ builder.ConvWithGeneralPadding(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 2}});
+
+ ComputeAndCompareR4<float>(&builder, {{{{12, 23, 30, 0}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 24,130,240
+ //
+ // This pattern will be fused to backward convolution with padding=(1,2).
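+ //
+ // For example, 24 = 0*100 + 2*10 + 4*1 is the first window of the padded
+ // activations against the dilated gradients; the next windows give
+ // 1*100 + 3*10 = 130 and 2*100 + 4*10 = 240.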
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {1, 2}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest,
+ BackwardFilterLowPaddingGreaterThanHighPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 13,24
+ //
+ // This pattern will be fused to backward convolution with padding=(2,1).
+ // Note: both (2,1) and (2,0) are valid padding for the backward convolution
+ // because the stride is 2.
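+ //
+ // For example, 13 = 0*100 + 1*10 + 3*1 is the first window of the padded
+ // activations against the dilated gradients, and 24 = 0*100 + 2*10 + 4*1
+ // is the second.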
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 0}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_);
+}
+
+TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
+ ComputationBuilder builder(client_, TestName());
+
+ // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0
+ // gradients: 100,10,1 -dilate-> 100,0,10,0,1
+ // weight gradients: 13,24,130
+ //
+ // This pattern will be fused to backward convolution with padding=(2,2).
+ // Note: both (2,1) and (2,2) are valid padding for the backward convolution
+ // because the stride is 2. ConvolutionFolding prefers (2,2) because cuDNN
+ // only supports even padding -- using (2,1) would require extra
+ // canonicalization work.
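+ //
+ // The three windows of the padded activations against the dilated gradients
+ // give 0*100 + 1*10 + 3*1 = 13, 0*100 + 2*10 + 4*1 = 24, and
+ // 1*100 + 3*10 + 0*1 = 130.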
+ auto activations = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 4, /*values=*/{1, 2, 3, 4}));
+ auto gradients = builder.ConstantR4FromArray4D<float>(
+ Array4D<float>(1, 1, 1, 3, /*values=*/{100, 10, 1}));
+ auto forward_conv = builder.ConvGeneralDilated(
+ activations, gradients,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
+ ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.Transpose(forward_conv, {0, 1, 2, 3});
+
+ ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc
new file mode 100644
index 0000000000..29e2950533
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/copy_test.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class CopyOpTest : public HloTestBase {
+ protected:
+ void TestCopyOp(const Literal& literal) {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(MakeUnique<Literal>(literal)));
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+ auto computation = builder.Build();
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectEqual(literal, *result);
+ }
+
+ void TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3);
+ void TestCopyConstantLayoutR4(size_t n1, size_t n2, size_t n3, size_t n4,
+ tensorflow::gtl::ArraySlice<int64> permutation);
+};
+
+TEST_F(CopyOpTest, CopyR0Bool) {
+ TestCopyOp(*LiteralUtil::CreateR0<bool>(true));
+}
+
+TEST_F(CopyOpTest, CopyR1S0U32) {
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({}));
+}
+
+TEST_F(CopyOpTest, CopyR1S3U32) {
+ TestCopyOp(*LiteralUtil::CreateR1<uint32>({1, 2, 3}));
+}
+
+TEST_F(CopyOpTest, CopyR3F32_2x2x3) {
+ TestCopyOp(
+ *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}},
+ {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}}));
+}
+
+TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) {
+ TestCopyOp(*LiteralUtil::CreateR4(
+ {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}},
+ {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}}));
+}
+
+TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) {
+ TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D<int32>(0, 2, 3, 2)));
+}
+
+TEST_F(CopyOpTest, CopyParameterScalar) {
+ auto builder = HloComputation::Builder(TestName());
+
+ // Copy literal to device to use as parameter.
+ auto literal = LiteralUtil::CreateR0<float>(42.0);
+ Shape shape = literal->shape();
+ auto constant_device_base = TransferToDevice(*literal);
+
+ auto param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param0"));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
+
+ auto computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {constant_device_base});
+ LiteralTestUtil::ExpectR0Near<float>(42.0f, *result, error_spec_);
+}
+
+TEST_F(CopyOpTest, CopyConstantR2Twice) {
+ auto builder = HloComputation::Builder(TestName());
+
+ auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ auto copy = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+ builder.AddInstruction(
+ HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, copy));
+
+ auto computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 2.0}, {3.0, 4.0}}, *result,
+ error_spec_);
+}
+
+TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) {
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ // Reverse the minor-to-major order of the literal.
+ Layout* literal_layout = literal->mutable_shape()->mutable_layout();
+ ASSERT_EQ(2, literal_layout->minor_to_major_size());
+ literal_layout->mutable_minor_to_major()->SwapElements(0, 1);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+
+ // The result of the computation has the default layout, which is the inverse
+ // of the layout of the source literal.
+ LiteralTestUtil::ExpectR2Near<float>({{1.0, 3.0}, {2.0, 4.0}}, *result,
+ error_spec_);
+}
+
+void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) {
+ Array3D<int32> a(n1, n2, n3);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ a(i, j, k) = i * n3 * n2 + j * n3 + k;
+ }
+ }
+ }
+
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR3FromArray3D(a);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto config = MakeUnique<HloModuleConfig>(computation->ComputeProgramShape());
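+ // Force the result layout to minor_to_major {1, 2, 0}, i.e. major-to-minor
+ // order 0,2,1 as in the "021" of the test names below.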
+ *config->mutable_entry_computation_layout()->mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(
+ constant->shape().element_type(),
+ AsInt64Slice(constant->shape().dimensions()), {1, 2, 0}));
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), std::move(config), {});
+
+ LiteralTestUtil::ExpectR3EqualArray3D(a, *result);
+}
+
+void CopyOpTest::TestCopyConstantLayoutR4(
+ size_t n1, size_t n2, size_t n3, size_t n4,
+ tensorflow::gtl::ArraySlice<int64> permutation) {
+ Array4D<int32> a(n1, n2, n3, n4);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ for (size_t l = 0; l < n4; ++l) {
+ a(i, j, k, l) = i * n4 * n3 * n2 + j * n4 * n3 + k * n4 + l;
+ }
+ }
+ }
+ }
+
+ HloComputation::Builder builder(TestName());
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR4FromArray4D(a);
+
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(std::move(literal)));
+
+ builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kCopy, constant));
+
+ std::unique_ptr<HloComputation> computation = builder.Build();
+
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto config = MakeUnique<HloModuleConfig>(computation->ComputeProgramShape());
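+ // The permutation gives the major-to-minor order (matching the test names
+ // below), so reverse it to obtain the minor_to_major layout.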
+ *config->mutable_entry_computation_layout()->mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(
+ constant->shape().element_type(),
+ AsInt64Slice(constant->shape().dimensions()), ({
+ std::vector<int64> p(permutation.rbegin(), permutation.rend());
+ p;
+ })));
+ hlo_module->AddEntryComputation(std::move(computation));
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), std::move(config), {});
+
+ LiteralTestUtil::ExpectR4EqualArray4D(a, *result);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleIncompleteTilePerLayer) {
+ TestCopyConstantLayout021(2, 2, 3);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_SingleCompleteTilePerLayer) {
+ TestCopyConstantLayout021(2, 32, 32);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR3Layout021_MultipleTilesPerLayer) {
+ TestCopyConstantLayout021(2, 70, 35);
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0231_MultipleTilesPerLayer) {
+ TestCopyConstantLayoutR4(2, 70, 7, 5, {0, 2, 3, 1});
+}
+
+XLA_TEST_F(CopyOpTest, CopyConstantR4Layout0312_MultipleTilesPerLayer) {
+ TestCopyConstantLayoutR4(2, 14, 5, 35, {0, 3, 1, 2});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
new file mode 100644
index 0000000000..dc54c9defe
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -0,0 +1,148 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <utility>
+
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/test.h"
+
+extern "C" void __attribute__((visibility("default")))
+R0F32Add2(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*));
+ *out = **in + 2.0f;
+}
+
+extern "C" void __attribute__((visibility("default")))
+R2F32ReduceSum(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
+ float* array = in[0];
+ *out = array[0] + array[1] + array[2] + array[3];
+}
+
+extern "C" void __attribute__((visibility("default")))
+Add1ToValues(float* out, float** in) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4);
+ float* array = in[0];
+ out[0] = array[0] + 1;
+ out[1] = array[1] + 1;
+ out[2] = array[2] + 1;
+ out[3] = array[3] + 1;
+}
+
+namespace xla {
+namespace {
+
+class CustomCallTest : public HloTestBase {
+ protected:
+ Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
+ Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
+};
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2"));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR0Near<float>(44.0f, *result, error_spec_);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto builder = HloComputation::Builder(TestName());
+
+ Array2D<float> array(2, 2);
+ array(0, 0) = 1.0f;
+ array(0, 1) = 2.0f;
+ array(1, 0) = 3.0f;
+ array(1, 1) = 4.0f;
+
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array)));
+ builder.AddInstruction(
+ HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum"));
+
+ hlo_module->AddEntryComputation(builder.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR0Near<float>(10.0f, *result, error_spec_);
+}
+
+XLA_TEST_F(CustomCallTest,
+ DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+ auto hlo_module = MakeUnique<HloModule>("test_module");
+ auto b = HloComputation::Builder(TestName());
+
+ auto input = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(
+ Array2D<float>{{1.0f, 2.0f}, {3.0f, 4.0f}})));
+ auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues"));
+ auto incremented_again = b.AddInstruction(HloInstruction::CreateCustomCall(
+ ShapeUtil::MakeShape(F32, {1, 2, 2}), {incremented}, "Add1ToValues"));
+
+ // Concatenate the values along the first dim.
+ b.AddInstruction(
+ HloInstruction::CreateConcatenate(ShapeUtil::MakeShape(F32, {2, 2, 2}),
+ {incremented, incremented_again}, 0));
+
+ hlo_module->AddEntryComputation(b.Build());
+
+ std::unique_ptr<Literal> result =
+ ExecuteAndTransfer(std::move(hlo_module), {});
+ LiteralTestUtil::ExpectR3EqualArray3D<float>(
+ Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc
new file mode 100644
index 0000000000..528efd2942
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/deallocation_test.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class DeallocationTest : public ClientLibraryTestBase {
+ protected:
+ // Build and execute the given computation, then verify that the results can
+ // be transferred from the device successfully.
+ std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ Computation computation = builder->Build().ConsumeValueOrDie();
+ auto global_data =
+ client_->Execute(computation, arguments).ConsumeValueOrDie();
+ TF_CHECK_OK(client_->Transfer(*global_data).status());
+ return global_data;
+ }
+};
+
+TEST_F(DeallocationTest, DeallocateScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR0<float>(42.0);
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ // A result can be transferred an arbitrary number of times. Add an extra
+ // transfer here so we're not just testing that a second call to Transfer
+ // fails.
+ ASSERT_IS_OK(client_->Transfer(*global_data).status());
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+TEST_F(DeallocationTest, DeallocateVector) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+TEST_F(DeallocationTest, DeallocateEmptyVector) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateTuple) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tuple({builder.ConstantR0<float>(42.0),
+ builder.ConstantR1<float>({1.0, 2.0, 3.0})});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateTupleWithRepeatedElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto element = builder.ConstantR0<float>(42.0);
+ auto inner_tuple = builder.Tuple({builder.ConstantR0<float>(42.0), element});
+ builder.Tuple({element, inner_tuple, element});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto inner_tuple =
+ builder.Tuple({builder.ConstantR0<float>(42.0),
+ builder.ConstantR1<float>({1.0, 2.0, 3.0})});
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({0.123, 0.456})});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ ASSERT_IS_OK(client_->Unregister(*global_data));
+
+ auto transfer_status = client_->Transfer(*global_data);
+ ASSERT_FALSE(transfer_status.ok());
+ ASSERT_MATCH(transfer_status.status().error_message(),
+ testing::HasSubstr("was previously deallocated"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
new file mode 100644
index 0000000000..57a7c61b14
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc
@@ -0,0 +1,215 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class DeconstructTupleTest : public ClientLibraryTestBase {
+ protected:
+ // Build and execute the given computation, then verify that the results can
+ // be transferred from the device successfully.
+ std::unique_ptr<GlobalData> ExecuteAndCheckTransfer(
+ ComputationBuilder* builder,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ Computation computation = builder->Build().ConsumeValueOrDie();
+ auto global_data =
+ client_->Execute(computation, arguments).ConsumeValueOrDie();
+ TF_CHECK_OK(client_->Transfer(*global_data).status());
+ return global_data;
+ }
+};
+
+TEST_F(DeconstructTupleTest, DeconstructTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+
+ // Try copying the elements back and comparing them.
+ auto handles = result_status.ConsumeValueOrDie();
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructTupleTwice) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status1 = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status1.ok());
+ auto result_status2 = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status2.ok());
+
+ auto handles1 = result_status1.ConsumeValueOrDie();
+ auto handles2 = result_status2.ConsumeValueOrDie();
+ std::vector<float> copy(4);
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles1[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles1[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ handles1[0].reset();
+ handles1[1].reset();
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles2[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles2[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructTupleRepeatedElement) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2, const2, const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+
+ // Verify that the returned vector of GlobalDataHandles has repeated
+ // elements, just as the tuple does: handle[0] should refer to the same data
+ // as handle[3], and handle[1] the same as handle[2].
+ auto handles = result_status.ConsumeValueOrDie();
+
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[3], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructTupleThenDeallocate) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({const1, const2, const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+ auto handles = result_status.ConsumeValueOrDie();
+
+ // Deallocate the tuple, then try copying the elements back. The elements
+ // should not have been deallocated because of reference counting.
+ global_data.reset();
+
+ std::vector<float> copy(4);
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[0], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[1], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({2.0, 4.0, 6.0, 8.0}));
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+
+ // Try deallocating one of the repeated elements, then copy again.
+ handles[0].reset();
+
+ ASSERT_IS_OK(client_->TransferInProcess(*handles[2], &copy[0]));
+ EXPECT_MATCH(copy, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+TEST_F(DeconstructTupleTest, DeconstructNonTuple) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(result_status.status().error_message(),
+ testing::ContainsRegex("global data handle .* is not a tuple"));
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+ builder.Tuple({p});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {param0_data.get()});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_TRUE(result_status.ok());
+ auto handles = result_status.ConsumeValueOrDie();
+ EXPECT_NE(handles[0]->handle().handle(), param0_data->handle().handle());
+}
+
+XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+ auto const1 = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto const2 = builder.ConstantR1<float>({2.0, 4.0, 6.0, 8.0});
+ builder.Tuple({builder.Tuple({const1, const2}), const1});
+ auto global_data = ExecuteAndCheckTransfer(&builder, {});
+
+ auto result_status = client_->DeconstructTuple(*global_data);
+ EXPECT_FALSE(result_status.ok());
+ EXPECT_MATCH(
+ result_status.status().error_message(),
+ testing::ContainsRegex("deconstructing nested tuples not yet supported"));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
new file mode 100644
index 0000000000..da2d43ca4f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -0,0 +1,387 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace {
+
+// TODO(mfdyck): use GUnit typed tests when we can do all tests on all backends.
+class DotOperationTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001, 1e-5};
+
+ protected:
+ template <typename Element>
+ void TestOneElementVectorDot();
+ template <typename Element>
+ void TestVectorDot();
+ template <typename Element>
+ void TestSquareMatrixDot(bool lhs_row_major = false,
+ bool rhs_row_major = false);
+ template <typename Element>
+ void TestNonsquareMatrixDot(bool lhs_row_major = false,
+ bool rhs_row_major = false);
+};
+
+XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<float>({});
+ auto rhs = builder.ConstantR1<float>({});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
+}
+
+template <typename Element>
+void DotOperationTest::TestOneElementVectorDot() {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<Element>({2.0});
+ auto rhs = builder.ConstantR1<Element>({3.0});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<Element>(&builder, 6.0, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, OneElementVectorDotF32) {
+ TestOneElementVectorDot<float>();
+}
+
+XLA_TEST_F(DotOperationTest, OneElementVectorDotF64) {
+ TestOneElementVectorDot<double>();
+}
+
+template <typename Element>
+void DotOperationTest::TestVectorDot() {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR1<Element>({1.0, 2.5, 42.0});
+ auto rhs = builder.ConstantR1<Element>({11.0, -1.0, 0.5});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR0<Element>(&builder, 29.5, {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, VectorDotF32) { TestVectorDot<float>(); }
+
+XLA_TEST_F(DotOperationTest, VectorDotF64) { TestVectorDot<double>(); }
+
+namespace {
+
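+// Returns the minor_to_major order for a 2D array: {1, 0} (column index is
+// the most minor dimension) for row-major, {0, 1} for column-major.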
+std::vector<int64> MinorToMajorForIsRowMajor(bool row_major) {
+ return {row_major ? 1 : 0, row_major ? 0 : 1};
+}
+
+} // namespace
+
+XLA_TEST_F(DotOperationTest, Dot_0x2_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_0x2_2x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto rhs = builder.ConstantR2<float>({{7.0, 8.0, 9.0}, {42.0, 77.0, 101.0}});
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 3), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_3x2_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs =
+ builder.ConstantR2<float>({{7.0, 8.0}, {9.0, 42.0}, {77.0, 101.0}});
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, Dot_2x0_0x2) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(2, 0));
+ auto rhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto result = builder.Dot(lhs, rhs);
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 2, 0.0f), {},
+ error_spec_);
+}
+
+template <typename Element>
+void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major,
+ bool rhs_row_major) {
+ auto lhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 2.0}, {3.0, -4.0}},
+ MinorToMajorForIsRowMajor(lhs_row_major)))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 6.0}, {7.0, -4.0}},
+ MinorToMajorForIsRowMajor(rhs_row_major)))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
+ auto result = builder.Dot(
+ builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), "lhs"),
+ builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 2}), "rhs"));
+
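+ // {{1, 2}, {3, -4}} * {{1, 6}, {7, -4}} == {{15, -2}, {-25, 34}}.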
+ Array2D<Element> expected({{15.0, -2.0}, {-25.0, 34.0}});
+ ComputeAndCompareR2<Element>(
+ &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = false;
+ TestSquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) {
+ TestSquareMatrixDot<float>(false, true);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) {
+ TestSquareMatrixDot<float>(true, false);
+}
+
+TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = true;
+ TestSquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) {
+ TestSquareMatrixDot<double>();
+}
+
+template <typename Element>
+void DotOperationTest::TestNonsquareMatrixDot(bool lhs_row_major,
+ bool rhs_row_major) {
+ auto lhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}},
+ MinorToMajorForIsRowMajor(lhs_row_major)))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(*test_utils::CreateR2LiteralWithLayout<Element>(
+ {{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}},
+ MinorToMajorForIsRowMajor(rhs_row_major)))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<Element>();
+ auto result = builder.Dot(
+ builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 3}), "lhs"),
+ builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), "rhs"));
+
+ Array2D<Element> expected({{26.0, 0.0}, {-12.0, 10.0}});
+
+ ComputeAndCompareR2<Element>(
+ &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFF) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = false;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorFT) {
+ constexpr bool kLhsRowMajor = false;
+ constexpr bool kRhsRowMajor = true;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = false;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) {
+ constexpr bool kLhsRowMajor = true;
+ constexpr bool kRhsRowMajor = true;
+ TestNonsquareMatrixDot<float>(kLhsRowMajor, kRhsRowMajor);
+}
+
+XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) {
+ TestNonsquareMatrixDot<double>();
+}
+
+TEST_F(DotOperationTest, ConcurrentMatMul) {
+ ComputationBuilder builder(client_, TestName());
+ auto matrix1 = builder.ConstantR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+ auto matrix2 = builder.ConstantR2<float>({{5.0, 6.0}, {7.0, 8.0}});
+ auto matrix12 = builder.Dot(matrix1, matrix2);
+ auto matrix21 = builder.Dot(matrix2, matrix1);
+ builder.Add(matrix12, matrix21);
+
+ Array2D<float> expected({{42.0, 56.0}, {74.0, 96.0}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// Regression test for b/32055648. The root of the graph is a kFusion of 4
+// bitcasts. Although bitcasts don't map to thunks, the root should still be
+// sync-dependent on bitcasts' operands.
+XLA_TEST_F(DotOperationTest, BatchMatMul) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "y");
+
+ auto x_flat = builder.Reshape(x, {0, 1, 2, 3}, {4, 2, 2});
+ auto y_flat = builder.Reshape(y, {0, 1, 2, 3}, {4, 2, 2});
+
+ // Slice batches into individual matrices and multiply them.
+ std::vector<xla::ComputationDataHandle> out_slices;
+ for (int i = 0; i < 4; ++i) {
+ // Slice off individual matrices and reshape to 2D tensors.
+ auto x_slice = builder.Slice(x_flat, {i, 0, 0}, {i + 1, 2, 2});
+ x_slice = builder.Reshape(x_slice, {0, 1, 2}, {2, 2});
+ auto y_slice = builder.Slice(y_flat, {i, 0, 0}, {i + 1, 2, 2});
+ y_slice = builder.Reshape(y_slice, {0, 1, 2}, {2, 2});
+
+ auto out = builder.Dot(x_slice, y_slice);
+ out = builder.Reshape(out, {0, 1}, {1, 2, 2});
+ out_slices.push_back(out);
+ }
+ auto out_flat = builder.ConcatInDim(out_slices, 0);
+ builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2});
+
+ auto x_data = client_
+ ->TransferToServer(*LiteralUtil::CreateR4<float>(
+ {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}},
+ {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}}))
+ .ConsumeValueOrDie();
+ auto y_data = client_
+ ->TransferToServer(*LiteralUtil::CreateR4<float>(
+ {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}},
+ {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}}))
+ .ConsumeValueOrDie();
+
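+ // For example, the first batch multiplies {{1000, 100}, {10, 1}} by
+ // {{1, 2}, {3, 4}}, giving {{1300, 2400}, {13, 24}}.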
+ ComputeAndCompareR4<float>(
+ &builder,
+ /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}},
+ {{{42900, 79200}, {429, 792}},
+ {{250800, 299200}, {2508, 2992}}}},
+ {x_data.get(), y_data.get()}, error_spec_);
+}
+
+TEST_F(DotOperationTest, TransposeFolding) {
+ for (bool transpose_lhs : {false, true}) {
+ for (bool transpose_rhs : {false, true}) {
+ for (bool row_major : {false, true}) {
+ std::unique_ptr<Array2D<float>> lhs(
+ new Array2D<float>({{1.0, 2.0, 3.0}, {3.0, -4.0, -1.0}}));
+ std::unique_ptr<Array2D<float>> rhs(
+ new Array2D<float>({{1.0, 6.0}, {2.0, 3.0}, {7.0, -4.0}}));
+
+ if (transpose_lhs) {
+ lhs = ReferenceUtil::TransposeArray2D(*lhs);
+ }
+ if (transpose_rhs) {
+ rhs = ReferenceUtil::TransposeArray2D(*rhs);
+ }
+ auto lhs_handle =
+ client_
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<float>(
+ *lhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
+ .ConsumeValueOrDie();
+ auto rhs_handle =
+ client_
+ ->TransferToServer(
+ *LiteralUtil::CreateR2FromArray2DWithLayout<float>(
+ *rhs, LayoutUtil::MakeLayout(
+ MinorToMajorForIsRowMajor(row_major))))
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto prim_type = primitive_util::NativeToPrimitiveType<float>();
+ auto lhs_arg = builder.Parameter(
+ 0, ShapeUtil::MakeShape(prim_type, {lhs->height(), lhs->width()}),
+ "lhs");
+ auto rhs_arg = builder.Parameter(
+ 1, ShapeUtil::MakeShape(prim_type, {rhs->height(), rhs->width()}),
+ "rhs");
+ if (transpose_lhs) {
+ lhs_arg = builder.Transpose(lhs_arg, {1, 0});
+ }
+ if (transpose_rhs) {
+ rhs_arg = builder.Transpose(rhs_arg, {1, 0});
+ }
+ auto result = builder.Dot(lhs_arg, rhs_arg);
+
+ Array2D<float> expected({{26.0, 0.0}, {-12.0, 10.0}});
+ VLOG(1) << "TestTransposeFolding " << transpose_lhs << " "
+ << transpose_rhs << " " << row_major;
+ ComputeAndCompareR2<float>(&builder, expected,
+ {lhs_handle.get(), rhs_handle.get()},
+ error_spec_);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendLayoutUtilFlags(&flag_list);
+ xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
new file mode 100644
index 0000000000..cecc4872df
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+
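+// Tests DynamicSlice: slicing a constant R1/R2/R3 operand at start indices
+// that are passed in as a runtime parameter.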
+class DynamicSliceTest : public ClientLibraryTestBase {
+ protected:
+ template <typename IndexT>
+ void TestR1() {
+ // Slice at dimension start.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {0}, {5},
+ {0.0, 1.0, 2.0, 3.0, 4.0});
+ // Slice in the middle.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {2}, {3},
+ {2.0, 3.0, 4.0});
+ // Slice at dimension boundaries.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {5}, {3},
+ {5.0, 6.0, 7.0});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0}, {6}, {4},
+ {6.0, 7.0, 0.0, 1.0});
+ }
+
+ template <typename IndexT>
+ void TestR2() {
+ // Slice at dimension start.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {0, 0}, {2, 2}, {{1.0f, 2.0f}, {4.0f, 5.0f}});
+ // Slice in the middle.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {1, 1}, {2, 1}, {{5.0f}, {8.0f}});
+    // Slice at dimension boundaries.
+    RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+                  {1, 1}, {2, 2}, {{5.0f, 6.0f}, {8.0f, 9.0f}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR2<IndexT>({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {1, 1}, {3, 3},
+ {{5.0f, 6.0f, 4.0f}, {8.0f, 9.0f, 7.0f}, {2.0f, 3.0f, 1.0f}});
+ }
+
+ template <typename IndexT>
+ void TestR3() {
+ // R3 Shape: [2, 3, 2]
+ // clang-format off
+
+ // Slice at dimension start.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 0, 0}, {2, 1, 2},
+ {{{1.0f, 2.0f}}, {{7.0f, 8.0f}}});
+
+ // Slice in the middle.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 1, 1}, {2, 2, 1},
+ {{{4.0f}, {6.0f}}, {{10.0f}, {12.0f}}});
+
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {0, 2, 1}, {2, 2, 1},
+ {{{6.0f}, {2.0f}}, {{12.0f}, {8.0f}}});
+
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void RunR1(const std::vector<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const std::vector<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR1<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR2(const Array2D<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const Array2D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR2FromArray2D<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR3(const Array3D<float>& input_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<int64> slice_sizes,
+ const Array3D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ builder.DynamicSlice(input, starts, slice_sizes);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+};
+
+XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1<uint64>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2<uint64>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3<int32>(); }
+
+XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3<int64>(); }
+
+XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3<uint64>(); }
+
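+// Tests DynamicUpdateSlice: writing an update array into a constant operand
+// at start indices that are passed in as a runtime parameter.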
+class DynamicUpdateSliceTest : public ClientLibraryTestBase {
+ protected:
+ template <typename IndexT>
+ void TestR1() {
+ // clang-format off
+ // Slice at dimension start.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {0},
+ {8.0, 9.0, 10.0, 3.0, 4.0, 5.0, 6.0, 7.0});
+ // Slice in the middle.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {2},
+ {0.0, 1.0, 8.0, 9.0, 10.0, 5.0, 6.0, 7.0});
+ // Slice at dimension boundaries.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {5},
+ {0.0, 1.0, 2.0, 3.0, 4.0, 8.0, 9.0, 10.0});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR1<IndexT>({0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0},
+ {8.0, 9.0, 10.0}, {6},
+ {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 8.0, 9.0});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void TestR2() {
+ // clang-format off
+ // Slice at dimension start.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {0, 0},
+ {{10.0f, 11.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}});
+ // Slice in the middle.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {1, 1},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 10.0f, 11.0f}, {7.0f, 8.0f, 9.0f}});
+ // Slice at dimension boundaries.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {2, 1},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 10.0f, 11.0f}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR2<IndexT>(
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f}},
+ {{10.0f, 11.0f}}, {2, 2},
+ {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 10.0f}});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void TestR3() {
+ // R3 Shape: [2, 3, 2]
+ // clang-format off
+ // Slice at dimension start.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f, 14.0f}, {15.0f, 16.0f}},
+ {{17.0f, 18.0f}, {19.0f, 20.0f}}},
+ {0, 0, 0},
+ {{{13.0f, 14.0f}, {15.0f, 16.0f}, {5.0f, 6.0f}},
+ {{17.0f, 18.0f}, {19.0f, 20.0f}, {11.0f, 12.0f}}});
+ // Slice in the middle.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f}, {15.0f}}},
+ {1, 1, 1},
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 13.0f}, {11.0f, 15.0f}}});
+ // Slice at dimension boundaries, but with sizes that cause indices to wrap.
+ RunR3<IndexT>(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 12.0f}}},
+ {{{13.0f}, {15.0f}}},
+ {1, 2, 1},
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}},
+ {{7.0f, 8.0f}, {9.0f, 10.0f}, {11.0f, 13.0f}}});
+ // clang-format on
+ }
+
+ template <typename IndexT>
+ void RunR1(const std::vector<float>& input_values,
+ const std::vector<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const std::vector<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR1<float>(input_values);
+ auto update = builder.ConstantR1<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR1<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR2(const Array2D<float>& input_values,
+ const Array2D<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const Array2D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR2FromArray2D<float>(input_values);
+ auto update = builder.ConstantR2FromArray2D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR2<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename IndexT>
+ void RunR3(const Array3D<float>& input_values,
+ const Array3D<float>& update_values,
+ const std::vector<IndexT> slice_starts,
+ const Array3D<float>& expected_values) {
+ ComputationBuilder builder(client_, TestName());
+ // Initialize and transfer dynamic slice start indices parameter.
+ ComputationDataHandle starts;
+ std::unique_ptr<GlobalData> start_data = CreateR1Parameter<IndexT>(
+ slice_starts, 0, "slice_starts", &builder, &starts);
+ // Build dynamic slice computation.
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ auto update = builder.ConstantR3FromArray3D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {start_data.get()},
+ ErrorSpec(0.000001));
+ }
+
+ void RunR3Contiguous(std::vector<int32> operand_shape, int32 index,
+ int32 size) {
+ const int32 kSeq = operand_shape[0];
+ const int32 kBatch = operand_shape[1];
+ const int32 kDim = operand_shape[2];
+ Array3D<float> input_values(kSeq, kBatch, kDim);
+ Array3D<float> update_values(size, kBatch, kDim);
+ Array3D<float> expected_values(kSeq, kBatch, kDim);
+
+ input_values.FillIota(0);
+ float val = 1000;
+ update_values.FillIota(val);
+
+    // TODO(b/34128753) Expected values may vary depending on the backend when
+    // the update wraps. According to the documentation, the results are
+    // implementation-specific when the update is out of bounds, so we don't
+    // really know what to pass into ComputeAndCompareR3.
+ expected_values.FillIota(0);
+ for (int i = 0; i < size; i++) {
+ for (int j = 0; j < kBatch; j++) {
+ for (int k = 0; k < kDim; k++) {
+ expected_values((index + i) % kSeq, j, k) = val++;
+ }
+ }
+ }
+ if (VLOG_IS_ON(1)) {
+ DumpArray<float>("input", input_values);
+ DumpArray<float>("update", update_values);
+ DumpArray<float>("expected", expected_values);
+ }
+
+ // Build dynamic slice computation.
+ ComputationBuilder builder(client_, TestName());
+ auto starts = builder.ConstantR1<int32>({index, 0, 0});
+ auto input = builder.ConstantR3FromArray3D<float>(input_values);
+ auto update = builder.ConstantR3FromArray3D<float>(update_values);
+ builder.DynamicUpdateSlice(input, update, starts);
+
+ // Run computation and compare against expected values.
+ ComputeAndCompareR3<float>(&builder, expected_values, {},
+ ErrorSpec(0.000001));
+ }
+
+ template <typename NativeT>
+ void DumpArray(const string& name, const Array3D<NativeT> values) {
+ std::unique_ptr<Literal> literal =
+ LiteralUtil::CreateR3FromArray3D<NativeT>(values);
+ LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal);
+ }
+};
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64>(); }
+
+XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64>(); }
+
+// Tests for the simple R3 case where the update is contiguous (i.e. the two
+// minor dimensions are not sliced).
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) {
+ // Single element, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) {
+  // Multiple elements, no wrap.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2);
+}
+
+// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
+// wrapping as expected.
+XLA_TEST_F(DynamicUpdateSliceTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousMultipleWrapping))) {
+  // Multiple elements, wrapping.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2);
+}
+
+// TODO(b/34128753) CPU and GPU failed on 2016-01-06. Appears not to handle
+// wrapping as expected.
+XLA_TEST_F(DynamicUpdateSliceTest,
+ DISABLED_ON_CPU(DISABLED_ON_GPU(R3ContiguousTooLarge))) {
+  // Multiple elements, update size larger than the operand.
+ std::vector<int32> operand_shape({4, 5, 2});
+ RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2);
+}
+
+XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) {
+ std::vector<int32> operand_shape({3, 123, 247});
+ RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1);
+}
+
+// TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error.
+XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) {
+ std::vector<int32> operand_shape({32, 128, 1024});
+ RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1);
+}
+
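+// Benchmarks a single DynamicSlice of size {1, 1, 1, 1} from a constant R4
+// input, with the start indices passed in as a device-resident parameter.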
+void BM_DynamicSlice(int num_iters) {
+ tensorflow::testing::StopTiming();
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+ auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+ StreamExecutorMemoryAllocator allocator(platform, executors);
+ LocalClient* client =
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+ auto* transfer_manager =
+ TransferManager::GetForPlatform(platform).ValueOrDie();
+ int device_ordinal = client->default_device_ordinal();
+
+ ComputationBuilder builder(client, "DynamicSlice");
+
+ // Create input as a constant: shape [1, 2, 3, 4]
+ auto input_literal = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+ auto input = builder.ConstantLiteral(*input_literal);
+
+ // Create dynamic slice start indices as a parameter: shape [4]
+ auto start_indices_shape = ShapeUtil::MakeShape(S32, {4});
+ auto start_indices =
+ builder.Parameter(0, start_indices_shape, "start_indices");
+  // Add DynamicSlice op to the computation.
+ builder.DynamicSlice(input, start_indices, {1, 1, 1, 1});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Initialize and transfer parameter buffer.
+ auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(start_indices_shape,
+ &allocator, 0)
+ .ConsumeValueOrDie();
+
+ auto start_indices_literal = LiteralUtil::CreateR1<int32>({0, 1, 2, 3});
+ ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice(
+ executors[device_ordinal], *start_indices_literal,
+ buffer->mutable_buffer({})));
+
+ // Run some warm-up executions.
+ LocalExecuteOptions options;
+ options.set_allocator(&allocator);
+ const int kWarmups = 2;
+ for (int i = 0; i < kWarmups; ++i) {
+ auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
+ ASSERT_TRUE(result.ok());
+ }
+
+ // Run benchmark.
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < num_iters; ++i) {
+ auto result = client->ExecuteLocally(computation, {buffer.get()}, options);
+ ASSERT_TRUE(result.ok());
+ }
+}
+BENCHMARK(BM_DynamicSlice);
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
new file mode 100644
index 0000000000..8e30063085
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc
@@ -0,0 +1,128 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <limits>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class FloorCeilTest : public ClientLibraryTestBase {
+ public:
+ enum Function {
+ kFloor,
+ kCeil,
+ };
+
+  // Runs a computation of f(input) and compares the result against expected.
+ void TestR1F32(tensorflow::gtl::ArraySlice<float> input,
+ tensorflow::gtl::ArraySlice<float> expected, Function f) {
+ LOG(INFO) << "input: {" << tensorflow::str_util::Join(expected, ", ")
+ << "}";
+ ComputationBuilder builder(client_, TestName());
+ auto c = builder.ConstantR1<float>(input);
+ if (f == kCeil) {
+ builder.Ceil(c);
+ } else {
+ ASSERT_EQ(kFloor, f);
+ builder.Floor(c);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, /*arguments=*/{});
+ }
+
+ void TestR0F32(float input, float expected, Function f) {
+ LOG(INFO) << "input: " << expected;
+ ComputationBuilder builder(client_, TestName());
+ auto c = builder.ConstantR0<float>(input);
+ if (f == kCeil) {
+ builder.Ceil(c);
+ } else {
+ ASSERT_EQ(kFloor, f);
+ builder.Floor(c);
+ }
+ ComputeAndCompareR0<float>(&builder, expected, /*arguments=*/{});
+ }
+
+ const ErrorSpec error_spec_{0.0001};
+
+ float infinity_ = std::numeric_limits<float>::infinity();
+ float minus_infinity_ = -std::numeric_limits<float>::infinity();
+};
+
+// Interesting notes:
+// * If you pass an sNaN, the CPU doesn't canonicalize it to a qNaN.
+// * Passing an x86-based CPU's qNaN to the GPU produces a different NaN:
+//   "7fc00000=nan=nan vs 7fffffff=nan=nan".
+
+XLA_TEST_F(FloorCeilTest, R1S0Floor) { TestR1F32({}, {}, kFloor); }
+
+TEST_F(FloorCeilTest, R1Floor) {
+ TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1},
+ {0.0, -0.0, infinity_, minus_infinity_, 1.0, -1.0}, kFloor);
+}
+
+TEST_F(FloorCeilTest, R1Ceil) {
+ TestR1F32({0.0, -0.0, infinity_, minus_infinity_, 1.1, -0.1},
+ {0.0, -0.0, infinity_, minus_infinity_, 2.0, -0.0}, kCeil);
+}
+
+TEST_F(FloorCeilTest, R0Floor) {
+ TestR0F32(0.0, 0.0, kFloor);
+ TestR0F32(-0.0, -0.0, kFloor);
+ TestR0F32(infinity_, infinity_, kFloor);
+ TestR0F32(minus_infinity_, minus_infinity_, kFloor);
+ TestR0F32(1.1, 1.0, kFloor);
+ TestR0F32(-0.1, -1.0, kFloor);
+}
+
+TEST_F(FloorCeilTest, R0Ceil) {
+ TestR0F32(0.0, 0.0, kCeil);
+ TestR0F32(-0.0, -0.0, kCeil);
+ TestR0F32(infinity_, infinity_, kCeil);
+ TestR0F32(minus_infinity_, minus_infinity_, kCeil);
+ TestR0F32(1.1, 2.0, kCeil);
+ TestR0F32(-0.1, -0.0, kCeil);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc
new file mode 100644
index 0000000000..2835038c90
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/fmax_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class FmaxSimpleTest : public ClientLibraryTestBase {};
+
+TEST_F(FmaxSimpleTest, FmaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {-0.0, 1.0, 2.0, -3.0, -4.0, 5.0, 6.0, -7.0, -8.0, 9.0});
+ auto y = builder.ConstantR1<float>(
+ {-0.0, -1.0, -2.0, 3.0, 4.0, -5.0, -6.0, 7.0, 8.0, -9.0});
+ builder.Max(x, y);
+
+ std::vector<float> expected = {-0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
new file mode 100644
index 0000000000..7bddbfa894
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -0,0 +1,589 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+#include <algorithm>
+#include <memory>
+#include <new>
+#include <utility>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+using tensorflow::gtl::ArraySlice;
+
+namespace xla {
+namespace {
+
+const int test_width = 2, test_height = 3;
+
+const float test_float_vals[3][test_width][test_height] = {
+ {{-1.0, -1.0, 1.0}, {-3.0, 0.0, -1.0}},
+ {{-3.0, 2.0, 1.0}, {0.0, -3.0, 1.0}},
+ {{-3.0, 0.0, -3.0}, {-1.0, -2.0, 1.0}}};
+
+// Tests whether fusion operations are emitted without errors and compute
+// accurate outputs.
+class FusionTest : public HloTestBase {
+ protected:
+ template <typename T, int Arity>
+ void TestElementwise2D(HloOpcode opcode) {
+ Array2D<float> operand_data[Arity];
+ for (int i = 0; i < Arity; ++i) {
+ new (&operand_data[i]) Array2D<float>(test_width, test_height);
+ }
+ Array2D<T> answer_data(test_width, test_height);
+ for (int i = 0; i < test_width; ++i) {
+ for (int j = 0; j < test_height; ++j) {
+ float xs[Arity];
+ for (int k = 0; k < Arity; ++k) {
+ xs[k] = test_float_vals[k][i][j];
+ operand_data[k](i, j) = xs[k];
+ }
+ answer_data(i, j) = ComputeElementwiseAnswer<T>(opcode, xs);
+ }
+ }
+
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto prim_type = primitive_util::NativeToPrimitiveType<T>();
+
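+    // hlos[0] will hold the fused root; the operands occupy hlos[1..Arity].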
+ HloInstruction* hlos[4];
+ for (int i = 0; i < Arity; ++i) {
+ hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2FromArray2D(operand_data[i])));
+ }
+ auto answer_shape =
+ ShapeUtil::MakeShape(prim_type, {test_width, test_height});
+ std::unique_ptr<HloInstruction> root_hlo;
+ switch (Arity) {
+ case 1:
+ root_hlo = HloInstruction::CreateUnary(answer_shape, opcode, hlos[1]);
+ break;
+ case 2:
+ root_hlo = HloInstruction::CreateBinary(answer_shape, opcode, hlos[1],
+ hlos[2]);
+ break;
+ case 3:
+ root_hlo = HloInstruction::CreateTernary(answer_shape, opcode, hlos[1],
+ hlos[2], hlos[3]);
+ break;
+ default:
+ LOG(FATAL) << "Bad arity: " << Arity;
+ }
+ hlos[0] = builder.AddInstruction(std::move(root_hlo));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(
+ ArraySlice<HloInstruction*>(hlos, 0, Arity + 1),
+ HloInstruction::FusionKind::kLoop);
+
+ auto expected = LiteralUtil::CreateR2FromArray2D(answer_data);
+ auto actual = ExecuteAndTransfer(std::move(hlo_module), {});
+ if (primitive_util::IsFloatingPointType(prim_type)) {
+ LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4));
+ } else {
+ LiteralTestUtil::ExpectEqual(*expected, *actual);
+ }
+ }
+
+ private:
+ template <typename T>
+ T ComputeElementwiseAnswer(HloOpcode opcode, ArraySlice<float> xs);
+};
+
+template <>
+float FusionTest::ComputeElementwiseAnswer<float>(HloOpcode opcode,
+ ArraySlice<float> xs) {
+ switch (opcode) {
+ case HloOpcode::kAdd:
+ return xs[0] + xs[1];
+ case HloOpcode::kSubtract:
+ return xs[0] - xs[1];
+ case HloOpcode::kMultiply:
+ return xs[0] * xs[1];
+ case HloOpcode::kDivide:
+ return xs[0] / xs[1];
+ case HloOpcode::kPower:
+ return powf(xs[0], xs[1]);
+ case HloOpcode::kMinimum:
+ return std::min(xs[0], xs[1]);
+ case HloOpcode::kMaximum:
+ return std::max(xs[0], xs[1]);
+ case HloOpcode::kClamp:
+ return std::min(xs[2], std::max(xs[1], xs[0]));
+ default:
+ LOG(FATAL) << "No elementwise opcode: " << opcode;
+ }
+}
+
+template <>
+uint8 FusionTest::ComputeElementwiseAnswer<uint8>(HloOpcode opcode,
+ ArraySlice<float> xs) {
+ switch (opcode) {
+ case HloOpcode::kEq:
+ return xs[0] == xs[1];
+ case HloOpcode::kNe:
+ return xs[0] != xs[1];
+ case HloOpcode::kGt:
+ return xs[0] > xs[1];
+ case HloOpcode::kLt:
+ return xs[0] < xs[1];
+ case HloOpcode::kGe:
+ return xs[0] >= xs[1];
+ case HloOpcode::kLe:
+ return xs[0] <= xs[1];
+ default:
+ LOG(FATAL) << "No comparatory opcode: " << opcode;
+ }
+}
+
+XLA_TEST_F(FusionTest, Test) {
+ // test expression:
+ // slice(select({{T, F, T}, {F, T, F}},
+ // concat(transpose({{1.0}, {2.0}, {3.0}} +
+ // {{-1.0}, {-1.0}, {-1.0}}),
+ // {{1.62, 2.72, 3.14}}) +
+ // (-{{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}),
+ // {{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})) = {{0.5}, {2.72}}
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0}, {2.0}, {3.0}})));
+ auto const1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-1.0}, {-1.0}, {-1.0}})));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1));
+ auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0}));
+ auto const4 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.62, 2.72, 3.14}})));
+ auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate(
+ ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0));
+ auto const6 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}})));
+ auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6));
+ auto add8 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7));
+ auto const9 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}})));
+ auto const10 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<bool>(
+ {{true, false, true}, {false, true, false}})));
+ auto select11 = builder.AddInstruction(
+ HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}),
+ HloOpcode::kSelect, const10, add8, const9));
+ auto slice12 = builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {2, 1}), select11, {0, 1}, {2, 2}));
+ // CreateFusionInstruction needs the `instructions_to_fuse` argument in
+ // reverse topological order, so the first element in `instructions_to_fuse`
+ // must be the root.
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(
+ {slice12, select11, const10, const9, add8, negate7, const6, concat5,
+ const4, reshape3, add2, const1, const0},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>({{0.5}, {2.72}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}),
+ ErrorSpec(1e-4));
+}
+
+// Test whether we emit appropriate code for parameters of fusion instructions.
+XLA_TEST_F(FusionTest, Parameter) {
+ // Build a computation and fuse part of it so the fusion instruction has an
+ // operand parameter.
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}})));
+ auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0));
+ auto const2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-2.0, -2.0, -2.0}})));
+ // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1}
+ auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2));
+ // CreateFusionInstruction needs `instructions_to_fuse` in reverse topological
+ // order.
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2<float>({{-1.0, 0.0, 1.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}),
+ ErrorSpec(1e-4));
+}
+
+XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));
+ auto const_array = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}})));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1}));
+ // add2 = broadcast(const_vector) + const_array
+ // = broadcast({1,2,3}) + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+ // = {{1, 2, 3}, {1, 2, 3}} + {{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}
+ auto add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {2, 3}),
+ HloOpcode::kAdd, broadcast, const_array));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{add2, broadcast},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectNear(
+ *LiteralUtil::CreateR2<float>({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4));
+}
+
+XLA_TEST_F(FusionTest, ReshapeToScalar) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto single_element_array = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR2<int32>({{5}})));
+ auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {}), single_element_array));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(5),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {1, 2, 3}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR3<int32>({{{1, 2, 3}, {4, 5, 6}}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 2}, {3, 4}, {5, 6}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_1by1by1_) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR3<int32>({{{7}}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape__1by1by1) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(
+ ShapeUtil::MakeShape(S32, {1, 1, 1}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3<int32>({{{7}}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape__) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(7)));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(7),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reshape_3by3_3by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ auto reshape1 = builder.AddInstruction(
+ HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Transpose_2by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 4}, {2, 5}, {3, 6}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Transpose_3by3) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}})));
+ auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1},
+ HloInstruction::FusionKind::kLoop);
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Reverse) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({1, 2, 3})));
+ auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse(
+ ShapeUtil::MakeShape(S32, {3}), const0, {0}));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1<int32>({3, 2, 1}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
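+// Builds a scalar S32 addition computation, used as the reduction function in
+// the Reduce tests below.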
+std::unique_ptr<HloComputation> MakeReduceTestComputation() {
+ auto builder = HloComputation::Builder("add");
+ auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(S32, {}), "lhs"));
+ auto rhs = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(S32, {}), "rhs"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, lhs, rhs));
+ return builder.Build();
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) {
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+ hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0<int32>(15),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) {
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR1<int32>({1, 2, 4, 8})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce(
+ ShapeUtil::MakeShape(S32, {}), const0, const1, {0},
+ hlo_module->AddEmbeddedComputation(MakeReduceTestComputation())));
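+  // The reduce produces the scalar 15; the negate's result shape {1} relies on
+  // implicit broadcast, so the expected output is {-15}.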
+ auto negate3 = builder.AddInstruction(HloInstruction::CreateUnary(
+ ShapeUtil::MakeShape(S32, {1}), HloOpcode::kNegate, reduce2));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2},
+ HloInstruction::FusionKind::kLoop);
+
+ LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1<int32>({-15}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) {
+ auto builder = HloComputation::Builder(TestName());
+ auto hlo_module = MakeUnique<HloModule>(TestName());
+ auto const0 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<int32>({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}})));
+ auto const1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(1)));
+ Window window;
+ ASSERT_TRUE(
+ tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n"
+ "size:2\n"
+ "stride:1\n"
+ "padding_low:0\n"
+ "padding_high:0\n"
+ "window_dilation:1\n"
+ "base_dilation:1\n"
+ "}\n"
+ "dimensions:{\n"
+ "size:2\n"
+ "stride:1\n"
+ "padding_low:0\n"
+ "padding_high:0\n"
+ "window_dilation:1\n"
+ "base_dilation:1\n"
+ "}\n",
+ &window));
+ auto nested_builder = HloComputation::Builder("mul");
+ {
+ auto x = nested_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(S32, {}), "x"));
+ auto y = nested_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(S32, {}), "y"));
+ nested_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(S32, {}), HloOpcode::kMultiply, x, y));
+ }
+ auto nested_computation =
+ hlo_module->AddEmbeddedComputation(nested_builder.Build());
+ auto reduce_window2 =
+ builder.AddInstruction(HloInstruction::CreateReduceWindow(
+ ShapeUtil::MakeShape(S32, {2, 2}), const0, const1, window,
+ nested_computation));
+ hlo_module->AddEntryComputation(builder.Build())
+ ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce_window2},
+ HloInstruction::FusionKind::kLoop);
+
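+  // Each 2x2 window is reduced by multiplication with an init value of 1;
+  // the top-left window, for example, yields 2 * 3 * 7 * 11 = 462.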
+ LiteralTestUtil::ExpectEqual(
+ *LiteralUtil::CreateR2<int32>({{462, 2145}, {24871, 62491}}),
+ *ExecuteAndTransfer(std::move(hlo_module), {}));
+}
+
+XLA_TEST_F(FusionTest, Add2D) { TestElementwise2D<float, 2>(HloOpcode::kAdd); }
+
+XLA_TEST_F(FusionTest, Subtract2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kSubtract);
+}
+
+XLA_TEST_F(FusionTest, Multiply2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMultiply);
+}
+
+XLA_TEST_F(FusionTest, Divide2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kDivide);
+}
+
+XLA_TEST_F(FusionTest, Power2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kPower);
+}
+
+XLA_TEST_F(FusionTest, Minimum2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMinimum);
+}
+
+XLA_TEST_F(FusionTest, Maximum2D) {
+ TestElementwise2D<float, 2>(HloOpcode::kMaximum);
+}
+
+XLA_TEST_F(FusionTest, Equal2D) { TestElementwise2D<uint8, 2>(HloOpcode::kEq); }
+
+XLA_TEST_F(FusionTest, Inequal2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kNe);
+}
+
+XLA_TEST_F(FusionTest, Greater2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kGt);
+}
+
+XLA_TEST_F(FusionTest, Lesser2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kLt);
+}
+
+XLA_TEST_F(FusionTest, GreaterOrEqual2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kGe);
+}
+
+XLA_TEST_F(FusionTest, LesserOrEqual2D) {
+ TestElementwise2D<uint8, 2>(HloOpcode::kLe);
+}
+
+XLA_TEST_F(FusionTest, Clamp2D) {
+ TestElementwise2D<float, 3>(HloOpcode::kClamp);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
new file mode 100644
index 0000000000..872188de81
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/hlo_test_base_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+// Defined in the .cc file to avoid having to include Eigen or forward-declare
+// these types in the header.
+struct HloTestBase::EigenThreadPoolWrapper {
+ std::unique_ptr<EigenThreadPoolWrapper> pool;
+ std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloTestBase::HloTestBase()
+ : backend_(Backend::CreateDefaultBackend().ConsumeValueOrDie()) {
+ test_hlo_dumper_ = [](const HloModule& module, const string& label) {
+ legacy_flags::HloTestBaseFlags* flags = legacy_flags::GetHloTestBaseFlags();
+ if (flags->xla_hlo_test_generate_hlo_graph) {
+ const bool show_addresses = true;
+ const bool show_layouts = true;
+ hlo_graph_dumper::DumpGraph(*module.entry_computation(), label,
+ show_addresses, show_layouts);
+ }
+ };
+ VLOG(1) << "executing on platform " << backend_->platform()->Name();
+}
+
+HloTestBase::~HloTestBase() {
+ // Deallocate all the memory allocated during the tests.
+ for (auto& allocation : allocations_) {
+ backend_->default_stream_executor()->Deallocate(&allocation);
+ }
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape) {
+ auto module_config = MakeUnique<HloModuleConfig>(
+ MakeProgramShape(module->entry_computation()));
+ return Execute(std::move(module), std::move(module_config), arguments,
+ result_shape);
+}
+
+StatusOr<se::DeviceMemoryBase> HloTestBase::Execute(
+ std::unique_ptr<HloModule> hlo_module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+ Shape* result_shape) {
+ VLOG(3) << "module_config layout "
+ << LayoutUtil::HumanString(module_config->entry_computation_layout()
+ .result_layout()
+ .layout());
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<Executable> executable,
+ backend_->compiler()->Compile(std::move(hlo_module),
+ std::move(module_config), test_hlo_dumper_,
+ backend_->default_stream_executor()));
+
+ se::Stream stream(backend_->default_stream_executor());
+ stream.Init();
+
+ ExecutableRunOptions run_options;
+ run_options.set_stream(&stream);
+ run_options.set_allocator(backend_->memory_allocator());
+ run_options.set_inter_op_thread_pool(backend_->inter_op_thread_pool());
+ run_options.set_intra_op_thread_pool(
+ backend_->eigen_intra_op_thread_pool_device());
+
+ HloExecutionProfile hlo_execution_profile;
+ TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase result,
+ executable->ExecuteOnStream(&run_options, arguments,
+ &hlo_execution_profile));
+ TF_RET_CHECK(stream.BlockHostUntilDone());
+
+ allocations_.push_back(result);
+
+ *result_shape = executable->result_shape();
+
+ if (ShapeUtil::IsTuple(*result_shape)) {
+ // We must record element buffers of tuples as well to avoid leaks.
+ DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+ TF_ASSIGN_OR_RETURN(
+ std::vector<se::DeviceMemoryBase> element_buffers,
+ backend_->transfer_manager()->ShallowCopyTupleFromDevice(
+ backend_->default_stream_executor(), result, *result_shape));
+
+ // A tuple may contain the same buffer in more than one element. Keep track
+ // of the buffers already added to avoid duplicates in allocations_.
+ std::set<void*> added_opaques;
+ for (auto element_buffer : element_buffers) {
+ if (added_opaques.count(element_buffer.opaque()) == 0) {
+ added_opaques.insert(element_buffer.opaque());
+ allocations_.push_back(element_buffer);
+ }
+ }
+ }
+
+ return result;
+}
+
+se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
+ // Allocate memory on the device using the stream executor.
+ int64 allocation_size =
+ backend_->transfer_manager()->GetByteSizeRequirement(literal.shape());
+ se::DeviceMemoryBase allocation =
+ backend_->default_stream_executor()->AllocateArray<uint8>(
+ allocation_size);
+ allocations_.push_back(allocation);
+
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralToDevice(
+ backend_->default_stream_executor(), literal, &allocation));
+
+ return allocation;
+}
+
+std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
+ const Shape& shape, se::DeviceMemoryBase device_base) {
+ auto literal = MakeUnique<Literal>();
+ TF_CHECK_OK(backend_->transfer_manager()->TransferLiteralFromDevice(
+ backend_->default_stream_executor(), device_base, shape, shape,
+ literal.get()));
+ return literal;
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+ Shape result_shape;
+ se::DeviceMemoryBase device_base =
+ Execute(std::move(module), std::move(module_config), arguments,
+ &result_shape)
+ .ValueOrDie();
+ return TransferFromDevice(result_shape, device_base);
+}
+
+ProgramShape HloTestBase::MakeProgramShape(HloComputation* computation) {
+ ProgramShape program_shape;
+ for (int64 i = 0; i < computation->num_parameters(); ++i) {
+ *program_shape.add_parameters() =
+ computation->parameter_instruction(i)->shape();
+ }
+ *program_shape.mutable_result() = computation->root_instruction()->shape();
+ return program_shape;
+}
+
+string HloTestBase::TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+}
+
+} // namespace xla
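For orientation, a minimal sketch of what a test built on this fixture looks like follows. The fixture and test names are hypothetical, and the HloComputation::Builder, HloInstruction, HloModule, and MakeUnique calls are assumed from the rest of this change; the sketch only illustrates the ExecuteAndTransfer flow implemented above, not a definitive pattern.

    // Hypothetical fixture and test, shown only to illustrate the flow above.
    class HloAddTest : public HloTestBase {};

    TEST_F(HloAddTest, AddsTwoScalarConstants) {
      // Build a tiny HLO graph: add two scalar constants.
      auto builder = HloComputation::Builder(TestName());
      auto lhs = builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.5f)));
      auto rhs = builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.5f)));
      builder.AddInstruction(HloInstruction::CreateBinary(
          ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, lhs, rhs));

      auto module = MakeUnique<HloModule>(TestName());
      module->AddEntryComputation(builder.Build());

      // Compiles the module, runs it on the default stream executor, and
      // copies the result back; device allocations are freed in ~HloTestBase.
      std::unique_ptr<Literal> result =
          ExecuteAndTransfer(std::move(module), /*arguments=*/{});
      LiteralTestUtil::ExpectR0Near<float>(4.0f, *result, error_spec_);
    }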
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
new file mode 100644
index 0000000000..fa88c76899
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_module_config.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+
+// A base class for tests which build and run HLO code. This is a lower level
+// of abstraction than the client interface and enables, for example, explicit
+// construction of a graph of HLO instructions to run.
+class HloTestBase : public ::testing::Test {
+ protected:
+ struct EigenThreadPoolWrapper;
+ HloTestBase();
+
+ ~HloTestBase() override;
+
+  // Executes the given module and returns the result as a device memory handle.
+ StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape);
+
+ // Variation of Execute which takes a custom module_config instead of creating
+ // a default one.
+ StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments,
+ Shape* result_shape);
+
+ // Transfers the given literal to the device and returns the data handle.
+ perftools::gputools::DeviceMemoryBase TransferToDevice(
+ const Literal& literal);
+
+  // Transfers the array referred to by the given handle from the device and
+  // returns it as a Literal.
+ std::unique_ptr<Literal> TransferFromDevice(
+ const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
+
+  // Executes the given module and returns the result as a Literal.
+ std::unique_ptr<Literal> ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments);
+
+ // Variation of ExecuteAndTransfer which takes a custom module_config instead
+ // of creating a default one.
+ std::unique_ptr<Literal> ExecuteAndTransfer(
+ std::unique_ptr<HloModule> module,
+ std::unique_ptr<HloModuleConfig> module_config,
+ tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+ arguments);
+
+ // Utility function which creates a ProgramShape for a given computation.
+ ProgramShape MakeProgramShape(HloComputation* computation);
+
+ string TestName() const;
+
+ std::unique_ptr<Backend> backend_;
+
+ Compiler::HloDumper test_hlo_dumper_;
+
+ // This vector contains handles of all the device memory allocations performed
+ // by the test. These are deallocated on destruction of the test object.
+ std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+
+ ErrorSpec error_spec_{0.0001};
+
+ std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/inprocess_service_test.cc b/tensorflow/compiler/xla/tests/inprocess_service_test.cc
new file mode 100644
index 0000000000..9909f041de
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/inprocess_service_test.cc
@@ -0,0 +1,204 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+// Tests which exercise the "InProcess" methods of xla::Client. The
+// "InProcess" methods require that the client and server share the same
+// process.
+class InProcessServiceTest : public ClientLibraryTestBase {
+ protected:
+ std::unique_ptr<GlobalData> ExecuteR2F32Constant(
+ std::initializer_list<std::initializer_list<float>> values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR2<float>(values);
+ auto computation = builder.Build().ConsumeValueOrDie();
+ CHECK_EQ(2, minor_to_major.size());
+ Shape shape_with_layout = ShapeUtil::MakeShapeWithLayout(
+ F32,
+ /*dimensions=*/{static_cast<int64>(values.size()),
+ static_cast<int64>(values.begin()->size())},
+ minor_to_major);
+ return client_
+ ->Execute(computation, {}, &shape_with_layout,
+ /*execution_profile=*/nullptr)
+ .ConsumeValueOrDie();
+ }
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+XLA_TEST_F(InProcessServiceTest, TransferFromServer) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>({1, 42, 5});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
+
+ std::vector<int32> result(3, 0);
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+ EXPECT_MATCH(result, testing::VectorMatcher<int32>({1, 42, 5}));
+}
+
+XLA_TEST_F(InProcessServiceTest, TransferToServer) {
+ std::vector<float> input{1.0f, 2.0f, -42.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ auto data_handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto param = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "param");
+ builder.Add(param, param);
+
+ ComputeAndCompareR1<float>(&builder, {2.0f, 4.0f, -84.0f},
+ {data_handle.get()}, error_spec_);
+}
+
+// TODO(b/28506710): This test case seems not to test inprocess
+// methods.
+TEST_F(InProcessServiceTest, GetShape) {
+ ComputationBuilder builder(client_, TestName());
+ builder.ConstantR1<int32>({1, 42, 5});
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto handle = client_->Execute(computation, {}).ConsumeValueOrDie();
+
+ Shape shape = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_EQ(S32, shape.element_type());
+ ASSERT_EQ(1, ShapeUtil::Rank(shape));
+ ASSERT_EQ(3, shape.dimensions(0));
+}
+
+XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayRowMajor) {
+ std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ shape.clear_layout();
+ *shape.mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ auto handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
+}
+
+XLA_TEST_F(InProcessServiceTest, GetShapeOfClientSuppliedArrayColMajor) {
+ std::vector<float> input{1.0f, 2.0f, 3.0f, 4.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 2});
+ shape.clear_layout();
+ *shape.mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ auto handle = client_->TransferToServerInProcess(shape, input.data())
+ .ConsumeValueOrDie();
+
+ Shape shape_returned = client_->GetShape(*handle).ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(shape, shape_returned));
+}
+
+TEST_F(InProcessServiceTest, TransferToServerNoLayout) {
+ std::vector<float> input{1.0f, 2.0f, -42.0f};
+ Shape shape = ShapeUtil::MakeShape(F32, {3});
+ shape.clear_layout();
+ auto transfer_status =
+ client_->TransferToServerInProcess(shape, input.data());
+ ASSERT_EQ(transfer_status.status().code(),
+ tensorflow::error::INVALID_ARGUMENT);
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteRowMajor) {
+ auto handle =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
+
+ std::vector<float> result(4, 0.0);
+ Shape shape;
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+
+ EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 2.0, 3.0, 4.0}));
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteColumnMajor) {
+ auto handle =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{0, 1});
+
+ std::vector<float> result(4, 0);
+ Shape shape;
+ ASSERT_IS_OK(client_->TransferInProcess(*handle, result.data()));
+
+ EXPECT_MATCH(result, testing::VectorMatcher<float>({1.0, 3.0, 2.0, 4.0}));
+}
+
+XLA_TEST_F(InProcessServiceTest, ExecuteAndReuseDifferentLayouts) {
+ // Create arrays on the server which have different layouts. Verify the
+ // computation still produces the correct results.
+ auto handle_rowmaj =
+ ExecuteR2F32Constant({{1.0, 2.0}, {3.0, 4.0}}, /*minor_to_major=*/{1, 0});
+
+ auto handle_colmaj = ExecuteR2F32Constant({{10.0, 20.0}, {30.0, 40.0}},
+ /*minor_to_major=*/{0, 1});
+
+ ComputationBuilder builder(client_, TestName());
+ auto param0 =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "param0");
+ auto param1 =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "param1");
+ builder.Add(param0, param1);
+
+ Array2D<float> expected({{11.0, 22.0}, {33.0, 44.0}});
+ ComputeAndCompareR2<float>(&builder, expected,
+ {handle_rowmaj.get(), handle_colmaj.get()},
+ error_spec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
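The ExecuteRowMajor and ExecuteColumnMajor tests above rely on the minor_to_major convention: the dimension listed first varies fastest in memory, so {1, 0} is row-major and {0, 1} is column-major for a rank-2 array. A small standalone sketch, independent of XLA, of how a 2x2 array linearizes under each layout:

    #include <cstdio>

    int main() {
      const float a[2][2] = {{1.0f, 2.0f}, {3.0f, 4.0f}};
      float row_major[4];  // minor_to_major = {1, 0}: dimension 1 varies fastest
      float col_major[4];  // minor_to_major = {0, 1}: dimension 0 varies fastest
      for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
          row_major[i * 2 + j] = a[i][j];
          col_major[j * 2 + i] = a[i][j];
        }
      }
      std::printf("row-major: %g %g %g %g\n",  // prints 1 2 3 4
                  row_major[0], row_major[1], row_major[2], row_major[3]);
      std::printf("col-major: %g %g %g %g\n",  // prints 1 3 2 4
                  col_major[0], col_major[1], col_major[2], col_major[3]);
      return 0;
    }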
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc
new file mode 100644
index 0000000000..f7bbc0f38b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util.cc
@@ -0,0 +1,566 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+
+#include <unistd.h>
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/index_util.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected,
+ const Shape& actual) {
+ ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual));
+ ASSERT_EQ(expected.element_type(), actual.element_type())
+ << PrimitiveType_Name(expected.element_type()) << " vs "
+ << PrimitiveType_Name(actual.element_type());
+ ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size());
+ for (int i = 0; i < expected.dimensions_size(); ++i) {
+ ASSERT_EQ(expected.dimensions(i), actual.dimensions(i))
+ << "mismatch in dimension #" << i
+ << " expected: " << ShapeUtil::HumanString(expected)
+ << " actual: " << ShapeUtil::HumanString(actual);
+ }
+ ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size());
+ for (int i = 0; i < expected.tuple_shapes_size(); ++i) {
+ AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i));
+ }
+}
+
+/* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts(
+ const Shape& expected, const Shape& actual) {
+ ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString());
+}
+
+namespace {
+
+string Hostname() {
+ char hostname[1024];
+ gethostname(hostname, sizeof hostname);
+ hostname[sizeof hostname - 1] = 0;
+ return string(hostname);
+}
+
+// Helper function that compares two values of a floating point type, FloatT,
+// for bitwise equality by bit-casting them to UnsignedT. On miscompare, a
+// descriptive error message is included in the AssertionFailure.
+template <typename FloatT, typename UnsignedT>
+testing::AssertionResult CompareFloatsBitwiseEqual(FloatT lhs, FloatT rhs) {
+ auto ulhs = tensorflow::bit_cast<UnsignedT>(lhs);
+ auto urhs = tensorflow::bit_cast<UnsignedT>(rhs);
+ if (ulhs != urhs) {
+ return testing::AssertionFailure() << tensorflow::strings::Printf(
+ "floating values are not bitwise-equal; and equality testing "
+ "was requested: %s=%g=%a vs %s=%g=%a",
+ tensorflow::strings::StrCat(tensorflow::strings::Hex(ulhs))
+ .c_str(),
+ lhs, lhs,
+ tensorflow::strings::StrCat(tensorflow::strings::Hex(urhs))
+ .c_str(),
+ rhs, rhs);
+ }
+ return testing::AssertionSuccess();
+}
+
+// Templated equality comparator. Floating-point types are specialized below
+// to use the bitwise helper above; this un-specialized fallback simply
+// compares with operator== and reports a gunit-style message on mismatch.
+template <typename NativeT>
+testing::AssertionResult CompareEqual(NativeT lhs, NativeT rhs) {
+ if (lhs == rhs) {
+ return testing::AssertionSuccess();
+ }
+ ::testing::Message msg;
+ msg << "Expected equality of these values:";
+ msg << "\n " << lhs;
+ msg << "\n " << rhs;
+
+ return testing::AssertionFailure() << msg;
+}
+
+// Specializations for floating types that do bitwise comparisons when equality
+// comparison is requested.
+template <>
+testing::AssertionResult CompareEqual<float>(float lhs, float rhs) {
+ return CompareFloatsBitwiseEqual<float, uint32>(lhs, rhs);
+}
+template <>
+testing::AssertionResult CompareEqual<double>(double lhs, double rhs) {
+ return CompareFloatsBitwiseEqual<double, uint64>(lhs, rhs);
+}
+
+// A recursive function which iterates through every index of the expected and
+// actual literals and compares their values elementwise. Returns true if all
+// elements are equal.
+template <typename NativeT>
+bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual,
+ tensorflow::gtl::MutableArraySlice<int64> multi_index,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ NativeT expected_value = LiteralUtil::Get<NativeT>(expected, multi_index);
+ NativeT actual_value = LiteralUtil::Get<NativeT>(actual, multi_index);
+ testing::AssertionResult result =
+ CompareEqual<NativeT>(expected_value, actual_value);
+    return result;  // Defines implicit coercion to bool.
+ }
+
+ bool all_match = true;
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index[dimension] = i;
+ all_match = all_match && ExpectLiteralsEqual<NativeT>(
+ expected, actual, multi_index, dimension + 1);
+ }
+ return all_match;
+}
+
+} // namespace
+
+/* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected,
+ const Literal& actual) {
+ EXPECT_TRUE(Equal(expected, actual)) << "expected:\n"
+ << LiteralUtil::ToString(expected)
+ << "\n\tvs actual:\n"
+ << LiteralUtil::ToString(actual);
+}
+
+/* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected,
+ const Literal& actual) {
+ EXPECT_FALSE(Equal(expected, actual));
+}
+
+/* static */ testing::AssertionResult LiteralTestUtil::Equal(
+ const Literal& expected, const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ AssertEqualShapes(expected.shape(), actual.shape());
+ std::vector<int64> multi_index(expected.shape().dimensions_size(), 0);
+ bool match = false;
+ switch (expected.shape().element_type()) {
+ case PRED:
+ match = ExpectLiteralsEqual<bool>(expected, actual, &multi_index, 0);
+ break;
+ case U8:
+ match = ExpectLiteralsEqual<uint8>(expected, actual, &multi_index, 0);
+ break;
+ case S32:
+ match = ExpectLiteralsEqual<int32>(expected, actual, &multi_index, 0);
+ break;
+ case S64:
+ match = ExpectLiteralsEqual<int64>(expected, actual, &multi_index, 0);
+ break;
+ case U32:
+ match = ExpectLiteralsEqual<uint32>(expected, actual, &multi_index, 0);
+ break;
+ case U64:
+ match = ExpectLiteralsEqual<uint64>(expected, actual, &multi_index, 0);
+ break;
+ case F32:
+ match = ExpectLiteralsEqual<float>(expected, actual, &multi_index, 0);
+ break;
+ case F64:
+ match = ExpectLiteralsEqual<double>(expected, actual, &multi_index, 0);
+ break;
+ case TUPLE: {
+ bool tuple_match = true;
+ for (int i = 0; i < actual.tuple_literals_size(); ++i) {
+ auto result =
+ Equal(expected.tuple_literals(i), actual.tuple_literals(i));
+ tuple_match = tuple_match ? !!result : false;
+ }
+ match = tuple_match;
+ break;
+ }
+ default:
+ LOG(FATAL)
+ << "Unsupported primitive type in LiteralTestUtil::ExpectEqual: "
+ << PrimitiveType_Name(expected.shape().element_type());
+ }
+ testing::AssertionResult result = testing::AssertionSuccess();
+ if (!match) {
+ result = testing::AssertionFailure()
+ << "expected: " << LiteralUtil::ToString(expected)
+ << "\nactual: " << LiteralUtil::ToString(actual);
+ VLOG(1) << result.message();
+ }
+ return result;
+}
+
+/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected,
+ const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape()));
+ ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape()));
+ AssertEqualShapes(expected.shape(), actual.shape());
+ for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+ const auto& expected_element = expected.tuple_literals(i);
+ const auto& actual_element = actual.tuple_literals(i);
+ if (ShapeUtil::IsTuple(expected_element.shape())) {
+ ExpectEqualTuple(expected_element, actual_element);
+ } else {
+ ExpectEqual(expected_element, actual_element);
+ }
+ }
+}
+
+namespace {
+
+// Helper class for comparing floating-point literals within an error bound.
+class NearComparator {
+ public:
+ explicit NearComparator(ErrorSpec error) : error_(error) {}
+
+ // Compares the two literals elementwise. EXPECTs each pair of elements to be
+ // within the error bound. Emits useful log messages and dumps literals to
+ // temporary files on failure. Returns true if literals match.
+ bool ExpectNear(const Literal& expected, const Literal& actual) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape());
+
+ // Set up members used during the comparison.
+ num_miscompares_ = 0;
+ abs_diff_sum_ = 0.0;
+ abs_expected_sum_ = 0.0;
+ abs_diff_miscompare_sum_ = 0.0;
+ abs_expected_miscompare_sum_ = 0.0;
+ max_rel_err_ = 0.0;
+ max_abs_err_ = 0.0;
+ *miscompares_.mutable_shape() =
+ ShapeUtil::ChangeElementType(actual.shape(), PRED);
+ miscompares_.mutable_preds()->Resize(
+ ShapeUtil::ElementsIn(miscompares_.shape()), false);
+ multi_index_.resize(expected.shape().dimensions_size(), 0);
+
+ switch (expected.shape().element_type()) {
+ case F32:
+ ExpectLiteralsNear<float>(expected, actual, 0);
+ break;
+ case F64:
+ ExpectLiteralsNear<double>(expected, actual, 0);
+ break;
+ default:
+ LOG(FATAL) << "Unsupported primitive type in near comparator: "
+ << PrimitiveType_Name(expected.shape().element_type())
+ << ". Must be floating-point type.";
+ }
+
+ if (num_miscompares_ > 0) {
+ if (!VLOG_IS_ON(1)) {
+ LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape())
+ << " " << LiteralUtil::ToString(expected);
+ LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape())
+ << " " << LiteralUtil::ToString(actual);
+ }
+ EXPECT_TRUE(num_miscompares_ == 0)
+ << "\nmax relative mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_)
+ << "\nmaximum relative error " << max_rel_err_
+ << "\nmax absolute mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_)
+ << "\nmaximum absolute error " << max_abs_err_
+ << "\nfirst mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(first_multi_index_)
+ << "\nlast mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(last_multi_index_)
+ << "\ntotal absolute error " << abs_diff_sum_
+ << "\ntotal absolute error of miscompares "
+ << abs_diff_miscompare_sum_ << "\ntotal relative error "
+ << (abs_diff_sum_ / abs_expected_sum_)
+ << "\ntotal relative error of miscompares "
+ << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_)
+ << "\nfailure count " << num_miscompares_;
+
+ WriteLiteralToTempFile(expected, "expected");
+ WriteLiteralToTempFile(actual, "actual");
+ WriteLiteralToTempFile(miscompares_, "miscompares");
+ }
+ return num_miscompares_ == 0;
+ }
+
+ private:
+  // EXPECTs that the two given scalar values are within the error bound. Keeps
+  // track of how many mismatches have occurred to keep the size of the output
+  // manageable.
+ template <typename NativeT>
+ bool ExpectValuesNear(NativeT expected, NativeT actual) {
+ if (expected == actual) {
+ return true;
+ }
+
+ float abs_diff = std::abs(actual - expected);
+ float rel_err = abs_diff / std::abs(expected);
+ abs_diff_sum_ += abs_diff;
+ abs_expected_sum_ += std::abs(expected);
+ if (rel_err > max_rel_err_) {
+ max_rel_err_ = rel_err;
+ max_rel_multi_index_ = multi_index_;
+ }
+ if (abs_diff > max_abs_err_) {
+ max_abs_err_ = abs_diff;
+ max_abs_multi_index_ = multi_index_;
+ }
+ VLOG(10) << tensorflow::strings::Printf(
+ "index %s abs_diff %f rel_err %f",
+ LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff,
+ rel_err);
+ bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
+ bool mismatch =
+ (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel));
+ if (mismatch) {
+ abs_diff_miscompare_sum_ += abs_diff;
+ abs_expected_miscompare_sum_ += std::abs(expected);
+ const int64 kMaxFailures = 2;
+ if (num_miscompares_ < kMaxFailures) {
+ EXPECT_NEAR(expected, actual, error_.abs)
+ << "mismatch at index "
+ << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff "
+ << abs_diff << " rel err " << rel_err << " failure #"
+ << num_miscompares_;
+ } else if (num_miscompares_ == kMaxFailures) {
+ LOG(ERROR)
+ << "reached max 'loud' failure count; silently proceeding...";
+ }
+ if (num_miscompares_ == 0) {
+ first_multi_index_ = multi_index_;
+ }
+ num_miscompares_++;
+ last_multi_index_ = multi_index_;
+ }
+ return !mismatch;
+ }
+
+ // Recursive function which compares the two given literals elementwise.
+ template <typename NativeT>
+ void ExpectLiteralsNear(const Literal& expected, const Literal& actual,
+ int64 dimension) {
+ if (dimension == expected.shape().dimensions_size()) {
+ bool near =
+ ExpectValuesNear(LiteralUtil::Get<NativeT>(expected, multi_index_),
+ LiteralUtil::Get<NativeT>(actual, multi_index_));
+ LiteralUtil::Set<bool>(&miscompares_, multi_index_, !near);
+ } else {
+ for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) {
+ multi_index_[dimension] = i;
+ ExpectLiteralsNear<NativeT>(expected, actual, dimension + 1);
+ }
+ }
+ }
+
+ // Writes the given literal to a file in the test temporary directory.
+ void WriteLiteralToTempFile(const Literal& literal, const string& name) {
+ int64 now_usec = tensorflow::Env::Default()->NowMicros();
+ string filename = tensorflow::io::JoinPath(
+ tensorflow::testing::TmpDir(),
+ tensorflow::strings::Printf("tempfile-%s-%llx-%s", Hostname().c_str(),
+ now_usec, name.c_str()));
+ TF_CHECK_OK(tensorflow::WriteBinaryProto(tensorflow::Env::Default(),
+ filename, literal));
+ LOG(ERROR) << "wrote to " << name << " file: " << filename;
+ }
+
+ ErrorSpec error_;
+
+ // Number of element miscomparisons encountered so far.
+ int64 num_miscompares_;
+
+  // A Literal of PREDs indicating which elements did not match between the
+  // expected and actual literals; it has the same shape as the compared
+  // literals.
+ Literal miscompares_;
+
+ // A multidimensional index used when performing the recursive comparison.
+ std::vector<int64> multi_index_;
+
+  // Aggregated statistics on the inputs.
+ double abs_diff_sum_;
+ double abs_expected_sum_;
+ double abs_diff_miscompare_sum_;
+ double abs_expected_miscompare_sum_;
+ float max_rel_err_;
+ float max_abs_err_;
+ std::vector<int64> first_multi_index_;
+ std::vector<int64> last_multi_index_;
+ std::vector<int64> max_rel_multi_index_;
+ std::vector<int64> max_abs_multi_index_;
+};
+
+} // namespace
+
+/* static */ testing::AssertionResult LiteralTestUtil::Near(
+ const Literal& expected, const Literal& actual, const ErrorSpec& error) {
+ NearComparator comparator(error);
+ return comparator.ExpectNear(expected, actual)
+ ? testing::AssertionSuccess()
+ : testing::AssertionFailure() << "values were not near";
+}
+
+/* static */ void LiteralTestUtil::ExpectNear(const Literal& expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ EXPECT_TRUE(Near(expected, actual, error));
+}
+
+/* static */ testing::AssertionResult LiteralTestUtil::NearTuple(
+ const Literal& expected, const Literal& actual, const ErrorSpec& error) {
+ VLOG(1) << "expected: " << LiteralUtil::ToString(expected);
+ VLOG(1) << "actual: " << LiteralUtil::ToString(actual);
+
+ if (!ShapeUtil::IsTuple(expected.shape()) ||
+ !ShapeUtil::IsTuple(actual.shape())) {
+ return testing::AssertionFailure()
+           << "tuples expected; expected shape = "
+ << expected.shape().ShortDebugString()
+ << " actual shape = " << actual.shape().ShortDebugString();
+ }
+ AssertEqualShapes(expected.shape(), actual.shape());
+ for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) {
+ const auto& expected_element = expected.tuple_literals(i);
+ const auto& actual_element = actual.tuple_literals(i);
+ if (ShapeUtil::IsTuple(expected_element.shape())) {
+ auto ret = NearTuple(expected_element, actual_element, error);
+ if (!ret) {
+ return ret;
+ }
+ } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) {
+ auto ret = Near(expected_element, actual_element, error);
+ if (!ret) {
+ return ret;
+ }
+ } else {
+ auto ret = Equal(expected_element, actual_element);
+ if (!ret) {
+ return ret;
+ }
+ }
+ }
+
+ return testing::AssertionSuccess();
+}
+
+/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ EXPECT_TRUE(NearTuple(expected, actual, error));
+}
+
+/* static */ string LiteralTestUtil::MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index) {
+ return tensorflow::strings::StrCat(
+ "{", tensorflow::str_util::Join(multi_index, ","), "}");
+}
+
+/* static */ std::unique_ptr<Literal> LiteralTestUtil::Reshape(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major, const Literal& literal) {
+ int64 new_num_elements = 1;
+ for (int64 i = 0; i < new_dimensions.size(); ++i) {
+ new_num_elements *= new_dimensions[i];
+ }
+ CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
+
+ auto new_literal = MakeUnique<Literal>();
+ *new_literal->mutable_shape() =
+ ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions);
+
+  // Create a new shape with the given minor-to-major layout. This shape is
+  // used solely for converting linear indices to multi-dimensional indices
+  // when writing elements to the new literal.
+ Shape shape_with_layout = new_literal->shape();
+ *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
+
+ // Allocate space in the new literal.
+ LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()),
+ new_literal.get());
+
+ // Copy data into new literal, element-by-element.
+ for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
+ std::vector<int64> from_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
+ std::vector<int64> to_multi_index =
+ IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
+ switch (literal.shape().element_type()) {
+ case PRED:
+ LiteralUtil::Set<bool>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<bool>(literal, from_multi_index));
+ break;
+ case U8:
+ LiteralUtil::Set<uint8>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint8>(literal, from_multi_index));
+ break;
+ case U32:
+ LiteralUtil::Set<uint32>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint32>(literal, from_multi_index));
+ break;
+ case S32:
+ LiteralUtil::Set<int32>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<int32>(literal, from_multi_index));
+ break;
+ case U64:
+ LiteralUtil::Set<uint64>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<uint64>(literal, from_multi_index));
+ break;
+ case S64:
+ LiteralUtil::Set<int64>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<int64>(literal, from_multi_index));
+ break;
+ case F32:
+ LiteralUtil::Set<float>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<float>(literal, from_multi_index));
+ break;
+ case F64:
+ LiteralUtil::Set<double>(
+ new_literal.get(), to_multi_index,
+ LiteralUtil::Get<double>(literal, from_multi_index));
+ break;
+ default:
+ LOG(FATAL) << "Unhandled primitive element type: "
+ << PrimitiveType_Name(literal.shape().element_type());
+ }
+ }
+
+ return new_literal;
+}
+
+} // namespace xla
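One detail of NearComparator worth calling out: a pair of finite values counts as a mismatch only when it violates both the absolute and the relative bound (or when exactly one of the two values is NaN). The sketch below restates that criterion outside the test framework; ErrorSpecSketch and Mismatches are hypothetical names, and the inputs are illustrative only.

    #include <cmath>
    #include <cstdio>

    // Mirrors ErrorSpec{abs, rel} and the check in ExpectValuesNear above.
    struct ErrorSpecSketch {
      float abs;
      float rel;
    };

    bool Mismatches(float expected, float actual, ErrorSpecSketch error) {
      if (expected == actual) return false;
      const float abs_diff = std::abs(actual - expected);
      const float rel_err = abs_diff / std::abs(expected);
      const bool nan_mismatch = std::isnan(actual) != std::isnan(expected);
      return nan_mismatch || (abs_diff >= error.abs && rel_err >= error.rel);
    }

    int main() {
      const ErrorSpecSketch spec{0.01f, 0.01f};
      // abs_diff = 0.005 < 0.01: within the absolute bound, so not a mismatch
      // even though the relative error is 5.0.
      std::printf("%d\n", Mismatches(0.001f, 0.006f, spec));  // prints 0
      // abs_diff = 1.0 and rel_err = 0.01: both bounds violated, so a mismatch.
      std::printf("%d\n", Mismatches(100.0f, 101.0f, spec));  // prints 1
      return 0;
    }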
diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h
new file mode 100644
index 0000000000..85656a53e4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util.h
@@ -0,0 +1,274 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
+
+#include <initializer_list>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// Structure describing permissible absolute and relative error bounds.
+struct ErrorSpec {
+ explicit ErrorSpec(float aabs, float arel = 0) : abs(aabs), rel(arel) {}
+
+ float abs; // Absolute error bound.
+ float rel; // Relative error bound.
+};
+
+// Utility class for making expectations/assertions related to XLA literals.
+class LiteralTestUtil {
+ public:
+ // Asserts that the given shapes have the same rank, dimension sizes, and
+ // primitive types.
+ static void AssertEqualShapes(const Shape& expected, const Shape& actual);
+
+ // Asserts that the provided shapes are equal as defined in AssertEqualShapes
+ // and that they have the same layout.
+ static void AssertEqualShapesAndLayouts(const Shape& expected,
+ const Shape& actual);
+
+  // Asserts that the expected and actual literals are (bitwise) equal for all
+  // elements. Also asserts that the rank, dimension sizes, and primitive type
+  // are equal.
+ static testing::AssertionResult Equal(
+ const Literal& expected, const Literal& actual) TF_MUST_USE_RESULT;
+
+ // Expects that expected and actual are Equal.
+ static void ExpectEqual(const Literal& expected, const Literal& actual);
+
+ // Expects that expected and actual are Not Equal.
+ static void ExpectNotEqual(const Literal& expected, const Literal& actual);
+
+  // Asserts that the given literal is (bitwise) equal to the expected values.
+ template <typename NativeT>
+ static void ExpectR0Equal(NativeT expected, const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR1Equal(tensorflow::gtl::ArraySlice<NativeT> expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR2Equal(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR3Equal(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual);
+
+  // Asserts that the given literal is (bitwise) equal to the given array.
+ template <typename NativeT>
+ static void ExpectR2EqualArray2D(const Array2D<NativeT>& expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR3EqualArray3D(const Array3D<NativeT>& expected,
+ const Literal& actual);
+ template <typename NativeT>
+ static void ExpectR4EqualArray4D(const Array4D<NativeT>& expected,
+ const Literal& actual);
+
+ // Expects that the values of the elements in the expected and actual tuples
+ // are equal. Tuples are matched recursively.
+ static void ExpectEqualTuple(const Literal& expected, const Literal& actual);
+
+  // Asserts that the expected and actual literals are within the given error
+  // bound for all elements. Also asserts that the rank, dimension sizes, and
+  // primitive types are equivalent. Only supported for floating point values.
+ static testing::AssertionResult Near(
+ const Literal& expected, const Literal& actual,
+ const ErrorSpec& error) TF_MUST_USE_RESULT;
+
+ // Expects expected and actual to be Near with the given error.
+ static void ExpectNear(const Literal& expected, const Literal& actual,
+ const ErrorSpec& error);
+
+  // Asserts that the given literal is within the given error bound of the
+  // given expected values. Only supported for floating point values.
+ template <typename NativeT>
+ static void ExpectR0Near(NativeT expected, const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR1Near(tensorflow::gtl::ArraySlice<NativeT> expected,
+ const Literal& actual, const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR2Near(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual, const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR3Near(
+ std::initializer_list<
+ std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error);
+
+  // Asserts that the given literal is within the given error bound of the
+  // given array. Only supported for floating point values.
+ template <typename NativeT>
+ static void ExpectR2NearArray2D(const Array2D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR3NearArray3D(const Array3D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+ template <typename NativeT>
+ static void ExpectR4NearArray4D(const Array4D<NativeT>& expected,
+ const Literal& actual,
+ const ErrorSpec& error);
+
+ // Returns whether the values of the elements in the expected and actual
+ // tuples are within the given error bound. Tuples are matched recursively.
+ // If the elements of the tuple are not floating-point types, the error spec
+ // is ignored and exact equality is checked.
+ static testing::AssertionResult NearTuple(
+ const Literal& expected, const Literal& actual,
+ const ErrorSpec& error) TF_MUST_USE_RESULT;
+
+ // Expects that the expected and actual values are near.
+ static void ExpectNearTuple(const Literal& expected, const Literal& actual,
+ const ErrorSpec& error);
+
+ // Returns a multi-dimensional index as a string. For example: '{7, 8}' will
+ // be returned for a 2-dimensional index with dimension 0 index equal to 7,
+ // dimension 1 equal to 8.
+ static string MultiIndexAsString(
+ tensorflow::gtl::ArraySlice<int64> multi_index);
+
+  // Creates a literal whose shape has the given new dimensions, using the
+  // data in the given input literal. For reshaping purposes, the (flat) data
+  // buffer of the input literal is assumed to have the given minor_to_major
+  // layout order.
+ static std::unique_ptr<Literal> Reshape(
+ tensorflow::gtl::ArraySlice<int64> new_dimensions,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major,
+ const Literal& literal);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(LiteralTestUtil);
+};
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR0<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR1Equal(
+ tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR1<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2Equal(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR2<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3Equal(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR3<NativeT>(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2EqualArray2D(
+ const Array2D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3EqualArray3D(
+ const Array3D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR4EqualArray4D(
+ const Array4D<NativeT>& expected, const Literal& actual) {
+ ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected,
+ const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR0<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR1Near(
+ tensorflow::gtl::ArraySlice<NativeT> expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR1<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2Near(
+ std::initializer_list<std::initializer_list<NativeT>> expected,
+ const Literal& actual, const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR2<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3Near(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ expected,
+ const Literal& actual, const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR3<NativeT>(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR2NearArray2D(
+ const Array2D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR3NearArray3D(
+ const Array3D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error);
+}
+
+template <typename NativeT>
+/* static */ void LiteralTestUtil::ExpectR4NearArray4D(
+ const Array4D<NativeT>& expected, const Literal& actual,
+ const ErrorSpec& error) {
+ ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error);
+}
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LITERAL_TEST_UTIL_H_
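A short usage sketch of the two comparison modes declared above. The test name and literal values are illustrative; the sketch only uses LiteralUtil::CreateR2, ExpectEqual, ExpectNear, and ErrorSpec from this change.

    #include "tensorflow/compiler/xla/literal_util.h"
    #include "tensorflow/compiler/xla/tests/literal_test_util.h"

    namespace xla {
    namespace {

    TEST(LiteralTestUtilUsageSketch, ExactAndApproximate) {
      // Exact (bitwise) comparison of two identical literals.
      auto expected = LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
      auto same = LiteralUtil::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
      LiteralTestUtil::ExpectEqual(*expected, *same);

      // Approximate comparison: ErrorSpec(0.001f) tolerates an absolute
      // difference below 0.001 in every element.
      auto noisy = LiteralUtil::CreateR2<float>({{1.0005f, 2.0f}, {3.0f, 4.0f}});
      LiteralTestUtil::ExpectNear(*expected, *noisy, ErrorSpec(0.001f));
    }

    }  // namespace
    }  // namespace xla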
diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
new file mode 100644
index 0000000000..fdec11c0e9
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that our utility functions for dealing with literals are correctly
+// implemented.
+
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) {
+ std::unique_ptr<Literal> literal = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
+ });
+ LiteralTestUtil::ExpectEqual(*literal, *literal);
+}
+
+TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) {
+ // Implementation note: we have to use a death test here, because you can't
+ // un-fail an assertion failure. The CHECK-failure is death, so we can make a
+ // death assertion.
+ auto unequal_things_are_equal = [] {
+ std::unique_ptr<Literal> lhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(42).get(),
+ LiteralUtil::CreateR0<int32>(64).get(),
+ });
+ std::unique_ptr<Literal> rhs = LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR0<int32>(64).get(),
+ LiteralUtil::CreateR0<int32>(42).get(),
+ });
+ CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal";
+ };
+ ASSERT_DEATH(unequal_things_are_equal(), "LHS and RHS are unequal");
+}
+
+TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) {
+ auto dummy_lambda = [] {
+ auto two = LiteralUtil::CreateR0<float>(2);
+ auto four = LiteralUtil::CreateR0<float>(4);
+ ErrorSpec error(0.001);
+ CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four";
+ };
+
+ tensorflow::Env* env = tensorflow::Env::Default();
+ string pattern =
+ tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/tempfile-*");
+ std::vector<string> files;
+ TF_CHECK_OK(env->GetMatchingPaths(pattern, &files));
+ for (const auto& f : files) {
+ TF_CHECK_OK(env->DeleteFile(f)) << f;
+ }
+
+ ASSERT_DEATH(dummy_lambda(), "two is not near four");
+
+  // Now check that we wrote temporary files to the temporary directory and
+  // that we can read them back.
+ std::vector<string> results;
+ TF_CHECK_OK(env->GetMatchingPaths(pattern, &results));
+
+ LOG(INFO) << "results: [" << tensorflow::str_util::Join(results, ", ") << "]";
+ EXPECT_EQ(3, results.size());
+ for (const string& result : results) {
+ Literal literal;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result,
+ &literal));
+ if (result.find("expected") != string::npos) {
+ EXPECT_EQ("2", LiteralUtil::ToString(literal));
+ } else if (result.find("actual") != string::npos) {
+ EXPECT_EQ("4", LiteralUtil::ToString(literal));
+ } else if (result.find("miscompares") != string::npos) {
+ EXPECT_EQ("true", LiteralUtil::ToString(literal));
+ } else {
+ FAIL() << "unknown file in temporary directory: " << result;
+ }
+ }
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
new file mode 100644
index 0000000000..591fff338c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/executable_run_options.h"
+#include "tensorflow/core/platform/dynamic_annotations.h"
+#include "tensorflow/core/platform/test.h"
+
+class LocalClientAotTest : public ::testing::Test {};
+
+// This is a compiled XLA computation which calls SumStructElements, and then
+// doubles the result.
+extern "C" void SumAndDouble(float* out, xla::ExecutableRunOptions* options,
+ void** parameters, void** temporary_buffers);
+
+// Just a struct with some arbitrary fields used to test the OPAQUE type.
+struct OpaqueData {
+ int field1 : 15;
+ int field2 : 14;
+ int field3 : 3;
+};
+
+// This is the implementation of a custom op which will be called by
+// SumAndDouble.
+extern "C" void SumStructElements(float* out, void** parameters) {
+ TF_ANNOTATE_MEMORY_IS_INITIALIZED(parameters, sizeof(OpaqueData*));
+ const auto* opaque_data = static_cast<OpaqueData*>(parameters[0]);
+ *out = opaque_data->field1 + opaque_data->field2 + opaque_data->field3;
+}
+
+TEST_F(LocalClientAotTest, Constant) {
+ xla::ExecutableRunOptions run_options;
+ OpaqueData opaque_data{100, 20, 3};
+ void* parameters[] = {&opaque_data};
+ float out = 0;
+ float tmp = 0;
+ void* temporary_buffers[] = {&out, &tmp, nullptr};
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ EXPECT_EQ(out, 246.0f);
+
+ opaque_data = {1, 2, 3};
+ SumAndDouble(&out, &run_options, parameters, temporary_buffers);
+ EXPECT_EQ(out, 12.0f);
+}
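For the record, the expected values follow directly from the bit-field initializers and the doubling step: (100 + 20 + 3) * 2 = 246 for the first call, and (1 + 2 + 3) * 2 = 12 after opaque_data is reassigned.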
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
new file mode 100644
index 0000000000..50e5dec0f6
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This program compiles an XLA computation that applies the SumStructElements
+// custom call and doubles the result, writing the object file to stdout.
+
+#include <iostream>
+#include <vector>
+
+#include "external/llvm/include/llvm/ADT/Triple.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+using xla::string;
+
+xla::Computation Doubler(xla::Client* client) {
+ xla::ComputationBuilder builder(client, "doubler");
+ auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
+ auto x = builder.Parameter(0, r0f32, "x");
+ builder.Mul(x, builder.ConstantR0<float>(2.0));
+ return std::move(builder.Build().ValueOrDie());
+}
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ auto client = xla::ClientLibrary::LocalClientOrDie();
+
+ xla::ComputationBuilder builder(client, "aot_test_helper");
+ auto opaque_shape = xla::ShapeUtil::MakeOpaqueShape();
+ auto opaque_param = builder.Parameter(0, opaque_shape, "x");
+ auto r0f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
+ auto sum = builder.CustomCall("SumStructElements", {opaque_param}, r0f32);
+ builder.Call(Doubler(client), {sum});
+
+ if (argc != 2) {
+ LOG(FATAL) << "local_client_aot_test_helper TARGET_CPU";
+ }
+
+ string triple_string;
+ string target_cpu = argv[1];
+ if (target_cpu == "k8") {
+ triple_string = "x86_64-none-linux-gnu";
+ } else if (target_cpu == "darwin") {
+ triple_string = "x86_64-apple-macosx";
+ } else if (target_cpu == "arm") {
+ triple_string = "aarch64-none-linux-gnu";
+ } else if (target_cpu == "ppc") {
+ triple_string = "powerpc64le-unknown-linux-gnu";
+ } else if (target_cpu == "local") {
+ triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple());
+ } else {
+ LOG(FATAL) << "unsupported TARGET_CPU: " << target_cpu;
+ }
+
+ llvm::Triple triple(xla::llvm_ir::AsStringRef(triple_string));
+
+ xla::cpu::CpuAotCompilationOptions options(
+ triple_string,
+ /*cpu_name=*/"", /*features=*/"", "SumAndDouble",
+ xla::cpu::CpuAotCompilationOptions::RelocationModel::Static);
+ auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
+ client
+ ->CompileAheadOfTime(builder.Build().ValueOrDie(),
+ /*argument_layouts=*/{&opaque_shape}, r0f32,
+ options)
+ .ConsumeValueOrDie());
+  // We should have three buffer-size entries: the result and a temp buffer
+  // (both float-sized), plus a -1 entry that the caller does not allocate.
+  // It's lame to hard-code this, but it keeps local_client_aot_test.cc simple.
+ CHECK_EQ(result->result_buffer_index(), 0);
+ CHECK_EQ(result->buffer_sizes().size(), 3);
+ CHECK_EQ(result->buffer_sizes()[0], sizeof(float)); // result buffer
+ CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // temp buffer
+ CHECK_EQ(result->buffer_sizes()[2], -1);
+ if (triple.isOSBinFormatELF()) {
+ // Check the ELF magic.
+ CHECK_EQ(result->object_file_data()[0], 0x7F);
+ CHECK_EQ(result->object_file_data()[1], 'E');
+ CHECK_EQ(result->object_file_data()[2], 'L');
+ CHECK_EQ(result->object_file_data()[3], 'F');
+ // Check the ELF class.
+ CHECK_EQ(result->object_file_data()[4], triple.isArch32Bit() ? 1 : 2);
+ // Check the ELF endianness: it should be little.
+ CHECK_EQ(result->object_file_data()[5], triple.isLittleEndian() ? 1 : 2);
+ // Check the ELF version: it should be 1.
+ CHECK_EQ(result->object_file_data()[6], 1);
+ }
+
+ const std::vector<char>& object_file_data = result->object_file_data();
+ std::cout.write(object_file_data.data(), object_file_data.size());
+
+ return 0;
+}
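The object-file checks above read the ELF identification bytes (e_ident). As a reference, here is a standalone sketch of the same check; the function and variable names are hypothetical, and the byte offsets come from the ELF specification rather than from XLA.

    #include <cstdio>

    // Offsets into e_ident, per the ELF specification:
    //   [0..3]  magic: 0x7F 'E' 'L' 'F'
    //   [4]     class: 1 = 32-bit, 2 = 64-bit
    //   [5]     data:  1 = little-endian, 2 = big-endian
    //   [6]     version: always 1
    bool LooksLikeLittleEndianElf64(const unsigned char* data) {
      return data[0] == 0x7F && data[1] == 'E' && data[2] == 'L' &&
             data[3] == 'F' && data[4] == 2 && data[5] == 1 && data[6] == 1;
    }

    int main() {
      const unsigned char fake_header[] = {0x7F, 'E', 'L', 'F', 2, 1, 1};
      std::printf("%d\n", LooksLikeLittleEndianElf64(fake_header));  // prints 1
      return 0;
    }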
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
new file mode 100644
index 0000000000..5c32ed8895
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+/* static */ TestAllocator* LocalClientTestBase::allocator_;
+
+StatusOr<perftools::gputools::DeviceMemoryBase> TestAllocator::Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) {
+ VLOG(2) << "Allocate(" << device_ordinal << ", " << size << ")";
+ {
+ tensorflow::mutex_lock lock(count_mutex_);
+ allocation_count_++;
+ device_allocation_count_[device_ordinal]++;
+ }
+ return StreamExecutorMemoryAllocator::Allocate(device_ordinal, size);
+}
+
+tensorflow::Status TestAllocator::Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) {
+ VLOG(2) << "Deallocate(" << device_ordinal << ")";
+ {
+ tensorflow::mutex_lock lock(count_mutex_);
+ deallocation_count_++;
+ device_deallocation_count_[device_ordinal]++;
+ }
+ return StreamExecutorMemoryAllocator::Deallocate(device_ordinal, mem);
+}
+
+int64 TestAllocator::allocation_count() const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ return allocation_count_;
+}
+
+int64 TestAllocator::allocation_count(int device_ordinal) const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ auto it = device_allocation_count_.find(device_ordinal);
+ if (it == device_allocation_count_.end()) {
+ return 0;
+ } else {
+ return it->second;
+ }
+}
+
+int64 TestAllocator::deallocation_count() const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ return deallocation_count_;
+}
+
+int64 TestAllocator::deallocation_count(int device_ordinal) const {
+ tensorflow::mutex_lock lock(count_mutex_);
+ auto it = device_deallocation_count_.find(device_ordinal);
+ if (it == device_deallocation_count_.end()) {
+ return 0;
+ } else {
+ return it->second;
+ }
+}
+
+/* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator(
+ perftools::gputools::Platform* platform) {
+ if (allocator_ == nullptr) {
+ allocator_ = new TestAllocator(
+ platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie()
+ : platform);
+ }
+ return allocator_;
+}
+
+LocalClientTestBase::LocalClientTestBase(
+ perftools::gputools::Platform* platform)
+ : local_client_(
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie()) {
+ stream_executor_ = PlatformUtil::GetStreamExecutors(local_client_->platform())
+ .ValueOrDie()[local_client_->default_device_ordinal()];
+ transfer_manager_ =
+ TransferManager::GetForPlatform(local_client_->platform()).ValueOrDie();
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) {
+ return LiteralToScopedShapedBuffer(literal,
+ local_client_->default_device_ordinal());
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal,
+ int device_ordinal) {
+ CHECK(!ShapeUtil::IsTuple(literal.shape()));
+ auto scoped_buffer =
+ ScopedShapedBuffer::MakeScopedShapedBuffer(
+ literal.shape(), GetOrCreateAllocator(local_client_->platform()),
+ device_ordinal)
+ .ConsumeValueOrDie();
+ // The creation of the scoped shaped buffer should allocate the buffer.
+ CHECK(!scoped_buffer->buffer(/*index=*/{}).is_null() ||
+ ShapeUtil::HasZeroElements(literal.shape()));
+ TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice(
+ stream_executor_, literal, scoped_buffer->mutable_buffer(/*index=*/{})));
+ return scoped_buffer;
+}
+
+void LocalClientTestBase::CopyShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer, ShapeIndex* index, Literal* literal) {
+ const Shape& shape = ShapeUtil::GetSubshape(shaped_buffer.shape(), *index);
+ if (ShapeUtil::IsTuple(shape)) {
+ *literal->mutable_shape() = shape;
+ for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
+ Literal* element_literal = literal->add_tuple_literals();
+ index->push_back(i);
+ CopyShapedBufferToLiteral(shaped_buffer, index, element_literal);
+ index->pop_back();
+ }
+ } else {
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralFromDevice(
+ stream_executor_, shaped_buffer.buffer(*index), shape, shape, literal));
+ }
+}
+
+std::unique_ptr<Literal> LocalClientTestBase::ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer) {
+ auto literal = MakeUnique<Literal>();
+ ShapeIndex index;
+ CopyShapedBufferToLiteral(shaped_buffer, &index, literal.get());
+ return literal;
+}
+
+std::unique_ptr<ScopedShapedBuffer>
+LocalClientTestBase::ShapedBufferToScopedShapedBuffer(
+ std::unique_ptr<ShapedBuffer> shaped_buffer,
+ DeviceMemoryAllocator* allocator) {
+ std::unique_ptr<ScopedShapedBuffer> scoped_buffer =
+ ScopedShapedBuffer::MakeScopedShapedBuffer(
+ shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())
+ .ConsumeValueOrDie();
+ // Deallocate the existing DeviceMemoryBase values in the newly created scoped
+ // buffer and replace them with the values from the shaped buffer.
+ for (perftools::gputools::DeviceMemoryBase& memory_base :
+ *scoped_buffer->mutable_buffers()) {
+ TF_CHECK_OK(
+ allocator->Deallocate(shaped_buffer->device_ordinal(), &memory_base));
+ }
+ *scoped_buffer->mutable_buffers() = shaped_buffer->buffers();
+
+ TF_CHECK_OK(
+ scoped_buffer->mutable_shape_index_to_buffer_entry()
+ ->ForEachMutableElement(
+ [&shaped_buffer](const ShapeIndex& index, bool is_leaf,
+ size_t* buffer_entry) -> ::tensorflow::Status {
+ if (is_leaf) {
+ *buffer_entry =
+ shaped_buffer->shape_index_to_buffer_entry().element(
+ index);
+ }
+ return tensorflow::Status::OK();
+ }));
+ return scoped_buffer;
+}
+
+LocalExecuteOptions LocalClientTestBase::DefaultLocalExecuteOptions() const {
+ return LocalExecuteOptions().set_allocator(
+ GetOrCreateAllocator(local_client_->platform()));
+}
+
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
+ return ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions());
+}
+
+std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options) {
+ return ShapedBufferToScopedShapedBuffer(
+ local_client_->ExecuteLocally(computation, arguments, options)
+ .ConsumeValueOrDie(),
+ options.allocator());
+}
+
+void LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result) {
+ ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions(), result);
+}
+
+void LocalClientTestBase::ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result) {
+ ASSERT_IS_OK(
+ local_client_->ExecuteLocally(computation, arguments, options, result));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
new file mode 100644
index 0000000000..62916d50e3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -0,0 +1,146 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/service/shaped_buffer.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+class TestAllocator : public StreamExecutorMemoryAllocator {
+ public:
+ explicit TestAllocator(perftools::gputools::Platform* platform)
+ : StreamExecutorMemoryAllocator(
+ platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
+ }
+
+ StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
+ int device_ordinal, uint64 size, bool retry_on_failure) override;
+ tensorflow::Status Deallocate(
+ int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) override;
+
+ // Return the number of allocations that have been performed.
+ int64 allocation_count() const;
+ int64 allocation_count(int device_ordinal) const;
+
+ // Return the number of deallocations that have been performed.
+ int64 deallocation_count() const;
+ int64 deallocation_count(int device_ordinal) const;
+
+ private:
+ mutable tensorflow::mutex count_mutex_;
+
+ // Global counts of allocations and deallocations.
+ int64 allocation_count_ GUARDED_BY(count_mutex_) = 0;
+ int64 deallocation_count_ GUARDED_BY(count_mutex_) = 0;
+
+ // Per-device counts of allocations and deallocations.
+ std::map<int, int64> device_allocation_count_ GUARDED_BY(count_mutex_);
+ std::map<int, int64> device_deallocation_count_ GUARDED_BY(count_mutex_);
+};
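+
+// Illustrative use in a derived test (not exercised directly in this file):
+// after running a computation, the counters can be used to check that every
+// device allocation was eventually freed, e.g.
+//   EXPECT_EQ(allocator->allocation_count(), allocator->deallocation_count());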
+
+// A base class for tests which exercise the LocalClient interface.
+class LocalClientTestBase : public ::testing::Test {
+ protected:
+ explicit LocalClientTestBase(
+ perftools::gputools::Platform* platform = nullptr);
+
+ static TestAllocator* GetOrCreateAllocator(
+ perftools::gputools::Platform* platform);
+
+ // Copy the given literal onto the default device and return a
+ // ScopedShapedBuffer.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal);
+ // As above, but copy to a specific device.
+ std::unique_ptr<ScopedShapedBuffer> LiteralToScopedShapedBuffer(
+ const Literal& literal, int device_ordinal);
+
+ // Construct and return a literal containing the array represented by
+ // shaped_buffer.
+ std::unique_ptr<Literal> ShapedBufferToLiteral(
+ const ShapedBuffer& shaped_buffer);
+
+ // Helper for converting a ShapedBuffer into a literal.
+ void CopyShapedBufferToLiteral(const ShapedBuffer& shaped_buffer,
+ ShapeIndex* index, Literal* literal);
+
+  // Execute the given computation on the local client, either with the
+  // default execute options or with explicitly provided options.
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
+ std::unique_ptr<ScopedShapedBuffer> ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options);
+
+ // Returns a default set of execute options, configured to use allocator_
+ // as the allocator.
+ LocalExecuteOptions DefaultLocalExecuteOptions() const;
+
+ // Overloads which write result into the given buffer.
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ ShapedBuffer* result);
+ void ExecuteLocally(
+ const Computation& computation,
+ tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+ const LocalExecuteOptions& options, ShapedBuffer* result);
+
+  // Convert a ShapedBuffer into a ScopedShapedBuffer so that all of its
+  // buffers are deallocated when the returned object is destroyed.
+ std::unique_ptr<ScopedShapedBuffer> ShapedBufferToScopedShapedBuffer(
+ std::unique_ptr<ShapedBuffer> shaped_buffer,
+ DeviceMemoryAllocator* allocator);
+
+ string TestName() const {
+ return ::testing::UnitTest::GetInstance()->current_test_info()->name();
+ }
+
+  // The allocator must live as long as the service, which lives until the end
+  // of the process, so make the allocator static.
+ static TestAllocator* allocator_;
+
+ perftools::gputools::StreamExecutor* stream_executor_;
+ TransferManager* transfer_manager_;
+
+ LocalClient* local_client_;
+};
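+
+// Illustrative round trip for a test derived from LocalClientTestBase (the
+// computation name and literal value below are hypothetical):
+//   auto arg = LiteralToScopedShapedBuffer(*LiteralUtil::CreateR0<float>(1.0f));
+//   auto result = ExecuteLocally(some_computation, {arg.get()});
+//   std::unique_ptr<Literal> out = ShapedBufferToLiteral(*result);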
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc
new file mode 100644
index 0000000000..b520d89de3
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/log_test.cc
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class LogTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(LogTest, LogZeroValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR3FromArray3D<float>(Array3D<float>(3, 0, 0));
+ builder.Log(x);
+
+ ComputeAndCompareR3<float>(&builder, Array3D<float>(3, 0, 0), {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(LogTest, LogTenValues) {
+ std::vector<float> input = {-0.0, 1.0, 2.0, -3.0, -4.0,
+ 5.0, 6.0, -7.0, -8.0, 9.0};
+
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(input);
+ builder.Log(x);
+
+ std::vector<float> expected;
+ for (float f : input) {
+ expected.push_back(std::log(f));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
new file mode 100644
index 0000000000..014417a205
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -0,0 +1,589 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla.pb.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class MapTest : public ClientLibraryTestBase {
+ public:
+ explicit MapTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+
+  // Creates a function that adds the constant 1.0 to its scalar argument.
+ //
+ // x {R0F32} ----> (add)
+ // /
+ // 1.0f ---------/
+ Computation CreateAdderToOne() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = mapped_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = mapped_builder.Add(x, one);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ Computation CreateMax() {
+ ComputationBuilder b(client_, TestName());
+ auto lhs = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto rhs = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ b.Max(lhs, rhs);
+ auto computation_status = b.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a computation that accepts an F32 and returns T(1) (ignoring the
+ // argument).
+ template <class T>
+ Computation CreateScalarOne() {
+ ComputationBuilder mapped_builder(client_, "scalar_one");
+ (void)mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ mapped_builder.ConstantR0<T>(1);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a function that multiplies its scalar argument by the constant 2.0
+ //
+ // x {R0F32} ----> (mul)
+ // /
+ // 2.0f ---------/
+ Computation CreateMulByTwo() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto two = mapped_builder.ConstantR0<float>(2.0);
+ auto mul_by_two = mapped_builder.Mul(x, two);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+  // Creates a function that adds the constant 1.0 to its scalar argument and
+  // then multiplies the sum by the original argument.
+ //
+ // /---------------\
+ // / \
+ // x {R0F32} ----> (add) ----> (mul)
+ // /
+ // 1.0f ---------/
+ Computation CreateAdderToOneTimesItself() {
+ ComputationBuilder mapped_builder(client_, TestName());
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto one = mapped_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = mapped_builder.Add(x, one);
+ auto result = mapped_builder.Mul(x, adder_to_one);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a function that takes a single parameter and calls map with
+ // "embedded_computation" on it, and then adds "n" to the result.
+ //
+ // x {R0F32} -----------> (map) ----> (add)
+ // / /
+ // embedded_computation --/ n --/
+ Computation CreateMapPlusN(const Computation& embedded_computation, float n) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto map = builder.Map({x}, embedded_computation);
+ auto constant_n = builder.ConstantR0<float>(n);
+ auto add = builder.Add(map, constant_n);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a binary function with signature (F32, F32) -> Pred
+ // defined by (x, y) -> x > y.
+ Computation CreateGt() {
+ ComputationBuilder b(client_, "Gt");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = b.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto gt = b.Gt(x, y);
+ auto computation_status = b.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ // Creates a function that adds three scalar arguments
+ //
+ // x {R0F32} ----\
+ // \
+ // y {R0F32} ----> (add) ---> (add)
+ // /
+ // z {R0F32} ---------------/
+ Computation CreateTernaryAdder() {
+ ComputationBuilder mapped_builder(client_, "TernaryAdder");
+ auto x = mapped_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = mapped_builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ auto z = mapped_builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "z");
+ auto xy = mapped_builder.Add(x, y);
+ auto xyz = mapped_builder.Add(xy, z);
+ auto computation_status = mapped_builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+};
+
+TEST_F(MapTest, MapEachElemPlusOneR0) {
+  // Applies (lambda (x) (+ x 1)) to an input scalar.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(42.0);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR0<float>(&builder, 43.0, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachElemPlusOneR1S4) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ ComputeAndCompareR1<float>(&builder, {3.2f, 4.3f, 5.4f, 6.5f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachF32ElementToS32Constant) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateScalarOne<int32>());
+
+ ComputeAndCompareR1<int32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
+}
+
+TEST_F(MapTest, MapEachF32ElementToU32Constant) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateScalarOne<uint32>());
+
+ ComputeAndCompareR1<uint32>(&builder, {1, 1, 1, 1}, {param0_data.get()});
+}
+
+TEST_F(MapTest, MapEachElemLongerChainR1) {
+ // Maps (lambda (x) (* (+ x 1) x)) onto an input R1F32 vector.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOneTimesItself());
+
+ ComputeAndCompareR1<float>(
+ &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, MapMultipleMapsR1S0) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then
+ // maps (lambda (x) (* x 2)) on the result.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map1 = builder.Map({param}, CreateAdderToOne());
+ auto map2 = builder.Map({map1}, CreateMulByTwo());
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapMultipleMapsR1S4) {
+ // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4, and then
+ // maps (lambda (x) (* x 2)) on the result.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map1 = builder.Map({param}, CreateAdderToOne());
+ auto map2 = builder.Map({map1}, CreateMulByTwo());
+
+ ComputeAndCompareR1<float>(&builder, {6.4f, 8.6f, 10.8f, 13.0f},
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapEachElemPlusOneR2) {
+ // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto param = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto map = builder.Map({param}, CreateAdderToOne());
+
+ Array2D<float> expected_array(
+ {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(MapTest, ComplexNestedMaps) {
+ // Constructs a complex graph of embedded computations to test the computation
+ // lowering order. Python equivalent:
+ //
+ // embed1 = lambda x: x + 1 # x + 1
+ // embed2 = lambda x: embed1(x) + 2 # x + 3
+ // embed3 = lambda x: embed1(x) + 4 # x + 5
+ // embed4 = lambda x: embed2(x) + embed3(x) # 2x + 8
+ // embed5 = lambda x: embed2(x) + 6 # x + 9
+ // result = embed5(42) + embed4(7) # (42 + 9) + (2 * 7 + 8) = 73
+
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+
+ auto embed1 = CreateAdderToOne();
+ auto embed2 = CreateMapPlusN(embed1, 2.0);
+ auto embed3 = CreateMapPlusN(embed1, 4.0);
+
+ ComputationBuilder embed4_builder(client_, "embed4");
+ auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x");
+ auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2);
+ auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3);
+ auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs);
+ auto embed4_status = embed4_builder.Build();
+ ASSERT_IS_OK(embed4_status.status());
+ auto embed4 = embed4_status.ConsumeValueOrDie();
+
+ auto embed5 = CreateMapPlusN(embed2, 6.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto constant_42 = builder.ConstantR0<float>(42.0);
+ auto constant_7 = builder.ConstantR0<float>(7.0);
+ auto map_42 = builder.Map({constant_42}, embed5);
+ auto map_7 = builder.Map({constant_7}, embed4);
+ builder.Add(map_42, map_7);
+
+ ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, VersionedEmbeddedComputation) {
+ // Build a computation X, use it in a map, then add an additional operation to
+ // computation X and use it again in a different map. Verify that the proper
+ // versions of computation X are used in each of the maps.
+
+  // Create an (embedded) computation which adds one to its parameter argument.
+ ComputationBuilder embedded_builder(client_, "EmbeddedComputation");
+ auto param_0 =
+ embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+ auto constant_one = embedded_builder.ConstantR0<float>(1.0);
+ auto adder_to_one = embedded_builder.Add(param_0, constant_one);
+ auto computation_status = embedded_builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ auto embedded_computation = computation_status.ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
+ auto map_plus_1 = builder.Map({constant_vector}, embedded_computation);
+
+ // Add another Add(1) operation to the existing embedded computation. This
+ // requires using the stub interface because the ComputationBuilder does not
+ // allow modification to the Computation objects after they have been built.
+ BinaryOpRequest request;
+ request.set_binop(BINOP_ADD);
+ *request.mutable_lhs() = adder_to_one;
+ *request.mutable_rhs() = constant_one;
+ OpRequest op_request;
+ *op_request.mutable_computation() = embedded_computation.handle();
+ *op_request.mutable_binary_op_request() = request;
+ OpResponse response;
+ tensorflow::Status s = client_->stub()->Op(&op_request, &response);
+ ASSERT_TRUE(s.ok());
+
+ auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation);
+
+ // The original vector has Add(1) applied to it with a map, followed by
+ // Add(1+1) resulting in a net Add(3).
+ ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapBinaryAdder) {
+ // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder));
+
+  ComputeAndCompareR1<float>(&builder, {7.3f, 7.7f, 4.3f, 0.0f},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.01f));
+}
+
+// Adds two rank-2 arrays with different layouts. This test exercises a path
+// for Map that used to fail in shape inference (b/28989438).
+XLA_TEST_F(MapTest, AddWithMixedLayouts) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ test_utils::CreateR2LiteralWithLayout({{1, 2}, {3, 4}}, {1, 0});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ test_utils::CreateR2LiteralWithLayout({{10, 20}, {30, 40}}, {0, 1});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+
+ Array2D<int32> expected(2, 2);
+ expected(0, 0) = 11;
+ expected(0, 1) = 22;
+ expected(1, 0) = 33;
+ expected(1, 1) = 44;
+ ComputeAndCompareR2<int32>(&builder, expected,
+ {param0_data.get(), param1_data.get()});
+}
+
+XLA_TEST_F(MapTest, AddR3_3x0x2) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR3FromArray3D<int32>(Array3D<int32>(3, 0, 2));
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map =
+ builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder));
+
+ ComputeAndCompareR3<int32>(&builder, Array3D<int32>(3, 0, 2),
+ {param0_data.get(), param1_data.get()});
+}
+
+TEST_F(MapTest, MapTernaryAdder) {
+ // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors.
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param2_literal =
+ LiteralUtil::CreateR1<float>({-10.0f, -100.0f, -900.0f, -400.0f});
+ std::unique_ptr<GlobalData> param2_data =
+ client_->TransferToServer(*param2_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto param2 = builder.Parameter(2, param2_literal->shape(), "param2");
+ auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder());
+
+ ComputeAndCompareR1<float>(
+ &builder, {-2.7f, -92.3f, -895.7f, -400.0f},
+ {param0_data.get(), param1_data.get(), param2_data.get()},
+ ErrorSpec(0.01f));
+}
+
+TEST_F(MapTest, MapGt) {
+ // Maps (x,y) -> x > y onto two R1F32 vectors.
+ ComputationBuilder b(client_, TestName());
+ auto gt = CreateGt();
+ b.Map({b.ConstantR1<float>({1, 20}), b.ConstantR1<float>({10, 2})}, gt);
+ ComputeAndCompareR1<bool>(&b, {false, true}, {});
+}
+
+TEST_F(MapTest, NestedBinaryMap) {
+ Computation max_with_square;
+ {
+    // max_with_square(x) computes max(x, x^2) via a map.
+ ComputationBuilder b(client_, "max_with_square");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ b.Map({x, b.Mul(x, x)}, CreateMax());
+ auto computation_status = b.Build();
+ ASSERT_IS_OK(computation_status.status());
+ max_with_square = computation_status.ConsumeValueOrDie();
+ }
+ ComputationBuilder b(client_, TestName());
+ auto input = b.ConstantR1<float>({0.1f, 0.5f, -0.5f, 1.0f, 2.0f});
+ b.Map({input}, max_with_square);
+ ComputeAndCompareR1<float>(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {});
+}
+
+TEST_F(MapTest, MapOperationWithBuildError) {
+ // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors but uses an unsupported
+ // type combination (F32 + U16) to test that the error is reported to the
+ // outermost ComputationBuilder.
+ ComputationBuilder builder(client_, TestName());
+
+ auto sub_builder = builder.CreateSubBuilder("ErrorAdd");
+ auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(U16, {}), "y");
+ auto adder = sub_builder->Add(x, y);
+ auto error_add = sub_builder->BuildAndNoteError();
+
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 3.3f, 4.4f, 5.5f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<Literal> param1_literal =
+ LiteralUtil::CreateR1<float>({5.1f, 4.4f, -0.1f, -5.5f});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ auto map = builder.Map({param0, param1}, error_add);
+
+ StatusOr<Computation> computation_status = builder.Build();
+ ASSERT_TRUE(!computation_status.ok());
+ EXPECT_MATCH(computation_status.status().ToString(),
+ testing::HasSubstr("error from: ErrorAdd: binary op with "
+ "different element types: f32[] and u16[]"));
+}
+
+// MapTest disables inline and algsimp. MapTestWithFullOpt runs all
+// optimizations.
+using MapTestWithFullOpt = ClientLibraryTestBase;
+
+// Regression test for b/31466798. The inliner simplifies map(param0, param1,
+// power) to power(param0, param1) without deleting the old subcomputation,
+// which is the same as the new entry computation. HloSubcomputationUnification
+// used to have issues with such patterns and could invalidate the pointer to
+// the entry computation.
+TEST_F(MapTestWithFullOpt, MapScalarPower) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto sub_builder = builder.CreateSubBuilder("power");
+ auto x = sub_builder->Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = sub_builder->Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ sub_builder->Pow(x, y);
+ auto power = sub_builder->BuildAndNoteError();
+
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR0<float>(2.0f);
+ std::unique_ptr<Literal> param1_literal = LiteralUtil::CreateR0<float>(5.0f);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto param1 = builder.Parameter(1, param1_literal->shape(), "param1");
+ builder.Map({param0, param1}, power);
+
+ ComputeAndCompareR0<float>(&builder, 32.0f,
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.01f));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
new file mode 100644
index 0000000000..8aa4029440
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class MatOpsSimpleTest : public ClientLibraryTestBase {
+ protected:
+ Computation BuildSum() {
+ // sum(x, y) = x + y
+ ComputationBuilder builder(client_, "sum");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto y_value =
+ builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y_value");
+ builder.Add(x_value, y_value);
+ auto computation_status = builder.Build();
+ TF_CHECK_OK(computation_status.status());
+ return computation_status.ConsumeValueOrDie();
+ }
+
+ void TestLinspaceMax(int64 rows, int64 cols) {
+ float from = -128.0, to = 256.0;
+ std::unique_ptr<Array2D<float>> alhs =
+ MakeLinspaceArray2D(from, to, rows, cols);
+ auto arhs = MakeUnique<Array2D<float>>(rows, cols, 1.0);
+
+ ComputationBuilder builder(
+ client_,
+ tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
+ auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
+ auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
+ auto max = builder.Max(lhs, rhs);
+
+ Array2D<float> aexpected(rows, cols);
+ for (int row = 0; row < rows; ++row) {
+ for (int col = 0; col < cols; ++col) {
+ aexpected(row, col) = std::max((*alhs)(row, col), (*arhs)(row, col));
+ }
+ }
+
+ ComputeAndCompareR2<float>(&builder, aexpected, {}, ErrorSpec(1e-6));
+ }
+};
+
+TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) {
+ ComputationBuilder builder(client_, "exp_2x2");
+ auto data = builder.ConstantR2<float>({
+ {1.0, 0.0}, // row 0
+ {-1.0, 0.5}, // row 1
+ });
+ builder.Exp(data);
+
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{2.71828, 1.00000}, // row 0
+ {0.36788, 1.64872}}); // row 1
+
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+}
+
+TEST_F(MatOpsSimpleTest, MapTwoByTwo) {
+ Computation add_half;
+ {
+ // add_half(x) = x + 0.5
+ ComputationBuilder builder(client_, "add_half");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Add(x_value, half);
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ add_half = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder builder(client_, "map_2x2");
+ auto data = builder.ConstantR2<float>({
+ {1.0, 0.0}, // row 0
+ {-1.0, 0.5}, // row 1
+ });
+ auto map = builder.Map({data}, add_half);
+
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{1.5, 0.5}, // row 0
+ {-0.5, 1.0}}); // row 1
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5));
+}
+
+TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) {
+ ComputationBuilder builder(client_, "max_2x2");
+ auto lhs = builder.ConstantR2<float>({
+ {7.0, 2.0}, // row 0
+ {3.0, -4.0}, // row 1
+ });
+ auto rhs = builder.ConstantR2<float>({
+ {5.0, 6.0}, // row 0
+ {1.0, -8.0}, // row 1
+ });
+ auto max = builder.Max(lhs, rhs);
+
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2<float>({{7.0, 6.0}, // row 0
+ {3.0, -4.0}}); // row 1
+ ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6));
+}
+
+TEST_F(MatOpsSimpleTest, Max1x1Linspace) { TestLinspaceMax(1, 1); }
+
+TEST_F(MatOpsSimpleTest, Max2x2Linspace) { TestLinspaceMax(2, 2); }
+
+TEST_F(MatOpsSimpleTest, Max3x3Linspace) { TestLinspaceMax(3, 3); }
+
+TEST_F(MatOpsSimpleTest, Max4x4Linspace) { TestLinspaceMax(4, 4); }
+
+TEST_F(MatOpsSimpleTest, Max6x6Linspace) { TestLinspaceMax(6, 6); }
+
+TEST_F(MatOpsSimpleTest, Max8x8Linspace) { TestLinspaceMax(8, 8); }
+
+TEST_F(MatOpsSimpleTest, Max12x12Linspace) { TestLinspaceMax(12, 12); }
+
+TEST_F(MatOpsSimpleTest, Max16x16Linspace) { TestLinspaceMax(16, 16); }
+
+TEST_F(MatOpsSimpleTest, Max32x8Linspace) { TestLinspaceMax(32, 8); }
+
+TEST_F(MatOpsSimpleTest, Max64x8Linspace) { TestLinspaceMax(64, 8); }
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
new file mode 100644
index 0000000000..2cd680399b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -0,0 +1,74 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that slice operations can be performed.
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class SliceTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(SliceTest, Slice2D) {
+ ComputationBuilder builder(client_, "slice_2d");
+ auto original = builder.ConstantR2<float>(
+ {{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
+ builder.Slice(original, {2, 1}, {4, 3});
+
+ Array2D<float> expected({{8.0f, 9.0f}, {11.0f, 12.0f}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+XLA_TEST_F(SliceTest, Slice3D) {
+ ComputationBuilder builder(client_, "slice_3d");
+ Array3D<float> array_3d(
+ {{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
+ auto original = builder.ConstantR3FromArray3D<float>(array_3d);
+ builder.Slice(original, {0, 0, 1}, {2, 1, 2});
+
+ Array3D<float> expected_3d({{{2.0f}}, {{6.0f}}});
+ ComputeAndCompareR3<float>(&builder, expected_3d, {}, ErrorSpec(0.000001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc
new file mode 100644
index 0000000000..d3400b432f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/pad_test.cc
@@ -0,0 +1,420 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class PadTest : public ClientLibraryTestBase {
+ protected:
+ PadTest() {
+ // Initializes the padding configuration used for R4 tests.
+    // Pad only dimension 0 {low: 1, high: 0, interior: 2} and
+    // dimension 1 {low: 0, high: 2, interior: 1}.
+ auto dimension0 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension0->set_edge_padding_low(1);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(2);
+ auto dimension1 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(2);
+ dimension1->set_interior_padding(1);
+ auto dimension2 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension2->set_edge_padding_low(0);
+ dimension2->set_edge_padding_high(0);
+ dimension2->set_interior_padding(0);
+ auto dimension3 = r4_padding_on_dim0_dim1_.add_dimensions();
+ dimension3->set_edge_padding_low(0);
+ dimension3->set_edge_padding_high(0);
+ dimension3->set_interior_padding(0);
+ }
+
+ // Padding configuration for R4 that only pads dimension 0 and 1.
+ PaddingConfig r4_padding_on_dim0_dim1_;
+};
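+
+// For reference, assuming the usual XLA pad size rule
+//   output_dim = low + high + input_dim + max(input_dim - 1, 0) * interior,
+// the configuration above maps a size-3 dimension 0 to 1 + 0 + 3 + 2*2 = 8 and
+// a size-2 dimension 1 to 0 + 2 + 2 + 1*1 = 5, which matches the expected
+// shapes in the R4 tests below.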
+
+// Tests a Pad() with a zero-element input and output.
+XLA_TEST_F(PadTest, Pad1DS0ToS0Array) {
+ ComputationBuilder b(client_, TestName());
+ // Set up the padding configuration {low: 0, high: 0, interior: 0}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(0);
+
+ b.Pad(b.ConstantR1<float>({}), b.ConstantR0<float>(0.1), padding_config);
+ ComputeAndCompareR1<float>(&b, {}, {}, ErrorSpec(0.0001));
+}
+
+// Tests a Pad() with a zero-element input but a non-zero-element output.
+XLA_TEST_F(PadTest, Pad1DS0ToS5Array) {
+ ComputationBuilder b(client_, TestName());
+  // Set up the padding configuration {low: 1, high: 4, interior: 7}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(1);
+ dimension->set_edge_padding_high(4);
+ dimension->set_interior_padding(7);
+
+ b.Pad(b.ConstantR1<float>({}), b.ConstantR0<float>(0.1), padding_config);
+ ComputeAndCompareR1<float>(&b, std::vector<float>(5, 0.1), {},
+ ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad1DS3Array) {
+ ComputationBuilder b(client_, TestName());
+ // Set up the padding configuration {low: 3, high: 0, interior: 1}.
+ PaddingConfig padding_config;
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(3);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(1);
+
+ b.Pad(b.ConstantR1<float>({1, 2, 3}), b.ConstantR0<float>(0.1),
+ padding_config);
+ std::vector<float> expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3});
+ ComputeAndCompareR1<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Pad(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 3, 2)),
+ b.ConstantR0<float>(1.5), r4_padding_on_dim0_dim1_);
+ ComputeAndCompareR4<float>(&b, Array4D<float>(5, 2, 3, 2, 1.5f), {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) {
+ ComputationBuilder b(client_, TestName());
+ auto input = MakeUnique<Array4D<float>>(1, 1, 3, 2);
+ Array2D<float> input_xy({
+ {1.0f, 2.0f}, // row 0
+ {3.0f, 4.0f}, // row 1
+ {5.0f, 6.0f}, // row 2
+ });
+ input->FillWithYX(input_xy);
+
+ b.Pad(b.ConstantR4FromArray4D<float>(*input), b.ConstantR0<float>(1.5),
+ r4_padding_on_dim0_dim1_);
+
+ auto expected = MakeUnique<Array4D<float>>(2, 3, 3, 2);
+ expected->Fill(1.5);
+ (*expected)(1, 0, 0, 0) = 1.0f;
+ (*expected)(1, 0, 0, 1) = 2.0f;
+ (*expected)(1, 0, 1, 0) = 3.0f;
+ (*expected)(1, 0, 1, 1) = 4.0f;
+ (*expected)(1, 0, 2, 0) = 5.0f;
+ (*expected)(1, 0, 2, 1) = 6.0f;
+ ComputeAndCompareR4<float>(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) {
+ ComputationBuilder b(client_, TestName());
+
+ const float pad_value = 1.5f;
+ Array4D<float> input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6});
+ b.Pad(b.ConstantR4FromArray4D<float>(input), b.ConstantR0<float>(pad_value),
+ r4_padding_on_dim0_dim1_);
+
+ auto expected = MakeUnique<Array4D<float>>(8, 5, 1, 1);
+ expected->Fill(pad_value);
+ (*expected)(1, 0, 0, 0) = 1.0f;
+ (*expected)(1, 2, 0, 0) = 2.0f;
+ (*expected)(4, 0, 0, 0) = 3.0f;
+ (*expected)(4, 2, 0, 0) = 4.0f;
+ (*expected)(7, 0, 0, 0) = 5.0f;
+ (*expected)(7, 2, 0, 0) = 6.0f;
+ ComputeAndCompareR4<float>(&b, *expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) {
+ ComputationBuilder b(client_, TestName());
+
+ PaddingConfig padding_config;
+ auto dimension0 = padding_config.add_dimensions();
+ dimension0->set_edge_padding_low(0);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(0);
+ auto dimension1 = padding_config.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(0);
+ dimension1->set_interior_padding(0);
+ auto dimension2 = padding_config.add_dimensions();
+ dimension2->set_edge_padding_low(2);
+ dimension2->set_edge_padding_high(1);
+ dimension2->set_interior_padding(0);
+ auto dimension3 = padding_config.add_dimensions();
+ dimension3->set_edge_padding_low(2);
+ dimension3->set_edge_padding_high(3);
+ dimension3->set_interior_padding(0);
+
+ const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3});
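+ // The layout is given minor-to-major, so dimension 0 is the most minor one.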
+
+ const float pad_value = -5.123f;
+ Array4D<float> input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6});
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ input = LiteralUtil::Relayout(*input, layout);
+
+ b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config);
+
+ Array4D<float> expected_array(1, 1, 5, 8);
+ expected_array.Fill(pad_value);
+ expected_array(0, 0, 2, 2) = 1.0f;
+ expected_array(0, 0, 2, 3) = 2.0f;
+ expected_array(0, 0, 2, 4) = 3.0f;
+ expected_array(0, 0, 3, 2) = 4.0f;
+ expected_array(0, 0, 3, 3) = 5.0f;
+ expected_array(0, 0, 3, 4) = 6.0f;
+ ComputeAndCompareR4<float>(&b, expected_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) {
+ ComputationBuilder b(client_, TestName());
+
+ PaddingConfig padding_config;
+ auto dimension0 = padding_config.add_dimensions();
+ dimension0->set_edge_padding_low(0);
+ dimension0->set_edge_padding_high(0);
+ dimension0->set_interior_padding(0);
+ auto dimension1 = padding_config.add_dimensions();
+ dimension1->set_edge_padding_low(0);
+ dimension1->set_edge_padding_high(0);
+ dimension1->set_interior_padding(0);
+ auto dimension2 = padding_config.add_dimensions();
+ dimension2->set_edge_padding_low(2);
+ dimension2->set_edge_padding_high(2);
+ dimension2->set_interior_padding(1);
+ auto dimension3 = padding_config.add_dimensions();
+ dimension3->set_edge_padding_low(2);
+ dimension3->set_edge_padding_high(2);
+ dimension3->set_interior_padding(0);
+
+ const Layout layout = LayoutUtil::MakeLayout({0, 1, 2, 3});
+
+ const float pad_value = -5.123f;
+ Array4D<float> input_array(1, 25, 7, 7);
+ input_array.Fill(pad_value);
+ input_array(0, 0, 0, 0) = 1.0f;
+ input_array(0, 24, 6, 6) = 2.0f;
+ input_array(0, 17, 2, 5) = 3.0f;
+ auto input = LiteralUtil::CreateR4FromArray4D<float>(input_array);
+ input = LiteralUtil::Relayout(*input, layout);
+
+ b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config);
+
+ Array4D<float> expected_array(1, 25, 17, 11);
+ expected_array.Fill(pad_value);
+ expected_array(0, 0, 2, 2) = 1.0f;
+ expected_array(0, 24, 14, 8) = 2.0f;
+ expected_array(0, 17, 6, 7) = 3.0f;
+ ComputeAndCompareR4<float>(&b, expected_array, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(PadTest, Pad4DU8Array) {
+ ComputationBuilder b(client_, TestName());
+ auto input = MakeUnique<Array4D<uint8>>(1, 1, 3, 2);
+ Array2D<uint8> input_xy({
+ {1, 2}, // row 0
+ {3, 4}, // row 1
+ {5, 6}, // row 2
+ });
+ input->FillWithYX(input_xy);
+
+ b.Pad(b.ConstantR4FromArray4D<uint8>(*input), b.ConstantR0<uint8>(35),
+ r4_padding_on_dim0_dim1_);
+
+ auto expected = MakeUnique<Array4D<uint8>>(2, 3, 3, 2);
+ expected->Fill(35);
+ (*expected)(1, 0, 0, 0) = 1;
+ (*expected)(1, 0, 0, 1) = 2;
+ (*expected)(1, 0, 1, 0) = 3;
+ (*expected)(1, 0, 1, 1) = 4;
+ (*expected)(1, 0, 2, 0) = 5;
+ (*expected)(1, 0, 2, 1) = 6;
+ ComputeAndCompareR4<uint8>(&b, *expected, {});
+}
+
+XLA_TEST_F(PadTest, Pad4DPredArray) {
+ ComputationBuilder b(client_, TestName());
+
+ // Since bool is currently not well supported, use a Broadcast operation to
+ // create the operand for Pad.
+ auto input = b.Broadcast(b.ConstantR0<bool>(true), {1, 1, 3, 2});
+ auto padded =
+ b.Pad(input, b.ConstantR0<bool>(false), r4_padding_on_dim0_dim1_);
+
+ // For the same reason, use Select to convert boolean values to int32.
+ auto zeros = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ auto ones = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ zeros->Fill(0);
+ ones->Fill(1);
+ b.Select(padded, b.ConstantR4FromArray4D<int32>(*ones),
+ b.ConstantR4FromArray4D<int32>(*zeros));
+
+ auto expected = MakeUnique<Array4D<int32>>(2, 3, 3, 2);
+ expected->Fill(0);
+ (*expected)(1, 0, 0, 0) = 1;
+ (*expected)(1, 0, 0, 1) = 1;
+ (*expected)(1, 0, 1, 0) = 1;
+ (*expected)(1, 0, 1, 1) = 1;
+ (*expected)(1, 0, 2, 0) = 1;
+ (*expected)(1, 0, 2, 1) = 1;
+ ComputeAndCompareR4<int32>(&b, *expected, {});
+}
+
+XLA_TEST_F(PadTest, Large2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ for (int dim : {0, 1}) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(
+ 98 + 100 * (1 - dim));
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 +
+ 100 * dim);
+ }
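+ // Dimension 0 gets low=198/high=58 and dimension 1 gets low=98/high=158, so
+ // the padded result is 260x260.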
+ auto padded = b.Pad(input, b.ConstantR0<float>(0.0f), padding_config);
+
+ auto ones = MakeUnique<Array2D<float>>(4, 4);
+ ones->Fill(1.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*ones);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f);
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()});
+}
+
+XLA_TEST_F(PadTest, AllTypes2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ constexpr int64 in_rows = 35;
+ constexpr int64 in_cols = 35;
+ auto input =
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ padding_config.mutable_dimensions(0)->set_edge_padding_low(7);
+ padding_config.mutable_dimensions(0)->set_edge_padding_high(5);
+ padding_config.mutable_dimensions(0)->set_interior_padding(3);
+ padding_config.mutable_dimensions(1)->set_edge_padding_low(6);
+ padding_config.mutable_dimensions(1)->set_edge_padding_high(4);
+ padding_config.mutable_dimensions(1)->set_interior_padding(2);
+ auto padded = b.Pad(input, b.ConstantR0<float>(3.14f), padding_config);
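+ // With interior padding each dimension grows to low + n + interior*(n-1) +
+ // high, i.e. (7 + 35 + 3*34 + 5) x (6 + 35 + 2*34 + 4) = 149x113 here.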
+
+ auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ operand->FillUnique(0.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f);
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()},
+ ErrorSpec{0.0001});
+}
+
+XLA_TEST_F(PadTest, High2DPad) {
+ ComputationBuilder b(client_, TestName());
+
+ constexpr int64 in_rows = 129;
+ constexpr int64 in_cols = 129;
+ constexpr int64 low_padding = 0;
+ int64 high_padding[2] = {5, 7};
+ constexpr int64 interior_padding = 0;
+ auto input =
+ b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input");
+ PaddingConfig padding_config = MakeNoPaddingConfig(2);
+ for (int dim : {0, 1}) {
+ padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding);
+ padding_config.mutable_dimensions(dim)->set_edge_padding_high(
+ high_padding[dim]);
+ padding_config.mutable_dimensions(dim)->set_interior_padding(
+ interior_padding);
+ }
+ auto padded = b.Pad(input, b.ConstantR0<float>(2.718f), padding_config);
+
+ auto operand = MakeUnique<Array2D<float>>(in_rows, in_cols);
+ operand->FillUnique(1.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D<float>(*operand);
+ auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputeAndCompareR2<float>(&b, *expected, {input_data.get()},
+ ErrorSpec(0.0001));
+}
+
+// Regression test for b/31827337.
+XLA_TEST_F(PadTest, ReducePad) {
+ ComputationBuilder b(client_, TestName());
+ auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input");
+
+ Computation add_f32 = CreateScalarAddComputation(F32, &b);
+ auto reduce = b.Reduce(input, b.ConstantR0<float>(0.0), add_f32, {0});
+
+ PaddingConfig padding_config = MakeNoPaddingConfig(3);
+ padding_config.mutable_dimensions(0)->set_edge_padding_low(1);
+ padding_config.mutable_dimensions(0)->set_edge_padding_high(1);
+ auto pad = b.Pad(reduce, b.ConstantR0<float>(0.0), padding_config);
+
+ auto ones = MakeUnique<Array4D<float>>(2, 2, 2, 2);
+ ones->Fill(1.0);
+ auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(*ones);
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
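+ // Reducing the all-ones [2,2,2,2] input over dimension 0 yields a [2,2,2]
+ // array of 2s; padding dimension 0 by 1 on each side gives the [4,2,2]
+ // result below.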
+ Array3D<float> expected({{{0.0, 0.0}, {0.0, 0.0}},
+ {{2.0, 2.0}, {2.0, 2.0}},
+ {{2.0, 2.0}, {2.0, 2.0}},
+ {{0.0, 0.0}, {0.0, 0.0}}});
+ ComputeAndCompareR3<float>(&b, expected, {input_data.get()});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
new file mode 100644
index 0000000000..2f05576cee
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -0,0 +1,357 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ParamsTest : public ClientLibraryTestBase {};
+
+XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR0<float>(3.14159f);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
+
+ ComputeAndCompareR0<float>(&builder, 3.14159f, {param0_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1<float>({});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "param0");
+
+ ComputeAndCompareR1<float>(&builder, {}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({3.14f, -100.25f});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0");
+
+ ComputeAndCompareR1<float>(&builder, {3.14f, -100.25f}, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
+ ComputationBuilder builder(client_, TestName());
+ string str("hello world");
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR1U8(str);
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(
+ 0, ShapeUtil::MakeShape(U8, {static_cast<int64>(str.size())}), "param0");
+
+ ComputeAndCompareR1U8(&builder, str, {param0_data.get()});
+}
+
+XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 0}), "param0");
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0),
+ {param0_data.get()}, ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal = LiteralUtil::CreateR2<float>(
+ {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+
+ auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3, 2}), "param0");
+
+ Array2D<float> expected_array(
+ {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_array, {param0_data.get()},
+ ErrorSpec(0.01f));
+}
+
+XLA_TEST_F(ParamsTest, TwoParameters) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+
+ // Use both parameters
+ //
+ // {1, 2} + {10, 20} = {11, 22}
+ auto sum = builder.Add(param0, param1);
+
+ // Use only the second parameter again, to show that it can be used
+ // twice and to make the computation asymmetric in the two
+ // parameters to test that the parameters are not swapped.
+ //
+ // {11, 22} * {10, 20} = {110, 440}
+ auto prod = builder.Mul(sum, param1);
+
+ ComputeAndCompareR1<float>(&builder, {110, 440},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, MissingParameter) {
+ // Test that an error is returned when a computation with an incomplete set of
+ // parameters (parameter numbers not contiguous from 0) is executed.
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<float>(3.14159f);
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ auto execute_status = client_->Execute(computation, {data.get(), data.get()},
+ /*output_layout=*/nullptr,
+ /*execution_profile=*/nullptr);
+ ASSERT_EQ(execute_status.status().code(),
+ tensorflow::error::FAILED_PRECONDITION);
+}
+
+XLA_TEST_F(ParamsTest, UnusedParameter) {
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+
+ std::unique_ptr<Literal> literal1 = LiteralUtil::CreateR1<float>({10, 20});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+
+ ComputeAndCompareR1<float>(&builder, {10, 20},
+ {param0_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
+ // Build a computation with a couple of parameters that are used only in an
+ // unused (dead) expression.
+ ComputationBuilder builder(client_, TestName());
+
+ std::unique_ptr<Literal> literal0 = LiteralUtil::CreateR1<float>({1, 2});
+ std::unique_ptr<GlobalData> param0_data =
+ client_->TransferToServer(*literal0).ConsumeValueOrDie();
+
+ std::unique_ptr<Literal> literal1 =
+ LiteralUtil::CreateR1<float>({10, 20, 30});
+ std::unique_ptr<GlobalData> param1_data =
+ client_->TransferToServer(*literal1).ConsumeValueOrDie();
+
+ auto param0 = builder.Parameter(0, literal0->shape(), "param0");
+ auto param1 = builder.Parameter(1, literal1->shape(), "param1");
+ auto param2 = builder.Parameter(2, literal1->shape(), "param2");
+
+ // This add is unused.
+ builder.Add(param1, param2);
+
+ builder.Neg(param0);
+
+ ComputeAndCompareR1<float>(
+ &builder, {-1, -2},
+ {param0_data.get(), param1_data.get(), param1_data.get()},
+ ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
+ ComputationBuilder builder(client_, TestName());
+ constexpr int size = 8 * 128 * 2;
+
+ std::vector<float> init_value = {{0, 1}};
+ init_value.resize(size);
+ ComputationDataHandle sum_handle = builder.ConstantR1<float>(init_value);
+ std::vector<float> sum = {{0, 1}};
+ sum.resize(size);
+
+ std::vector<std::unique_ptr<GlobalData>> param_data_owner;
+
+ constexpr int parameter_count = 100;
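+ // Each parameter is a length-`size` vector whose first two entries are
+ // (i, 2*i) and whose remaining entries are zero, so only the first two
+ // entries of the expected sum are nonzero.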
+ for (int i = 0; i < parameter_count; ++i) {
+ const float entry0 = i;
+ const float entry1 = 2 * i;
+ sum[0] += entry0;
+ sum[1] += entry1;
+
+ std::vector<float> sum_value = {{entry0, entry1}};
+ sum_value.resize(size);
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR1<float>(sum_value);
+ param_data_owner.push_back(
+ client_->TransferToServer(*literal).ConsumeValueOrDie());
+ ComputationDataHandle param =
+ builder.Parameter(i, literal->shape(), "param");
+ sum_handle = builder.Add(sum_handle, param);
+ }
+
+ std::vector<GlobalData*> param_data;
+ for (const std::unique_ptr<GlobalData>& data : param_data_owner) {
+ param_data.push_back(data.get());
+ }
+
+ ComputeAndCompareR1<float>(&builder, sum, param_data, ErrorSpec(0.0001f));
+}
+
+XLA_TEST_F(ParamsTest,
+ DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) {
+ ComputationBuilder builder(client_, TestName());
+
+ Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3});
+ auto input = builder.Parameter(0, tuple_shape, "input");
+ auto lhs = builder.GetTupleElement(input, 0);
+ auto rhs = builder.GetTupleElement(input, 1);
+ builder.Add(lhs, rhs);
+
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::CreateR1<float>({1, 2, 3}).get(),
+ LiteralUtil::CreateR1<float>({4, 5, 6}).get(),
+ }))
+ .ConsumeValueOrDie();
+
+ std::vector<GlobalData*> arguments = {data.get()};
+ const std::vector<float> expected = {1 + 4, 2 + 5, 3 + 6};
+ ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
+}
+
+// Verifies that passing a 2x2 with {0, 1} layout returns the same value back
+// when (transferred to the server and) passed through a parameter.
+XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 2}, {3, 4},
+ });
+ *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, literal->shape(), "input");
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+}
+
+// As above, but for {1, 0} layout.
+XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 3}, {2, 4},
+ });
+ *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+ ComputationBuilder builder(client_, TestName());
+ builder.Parameter(0, literal->shape(), "input");
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ ComputeAndCompareLiteral(&builder, *literal, {data.get()}, ErrorSpec(1e-3));
+}
+
+XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR2<float>({
+ {1, 3}, {2, 4},
+ });
+ const Shape original = literal->shape();
+ {
+ // Reverse the layout present in original, and make that the layout of the
+ // literal.
+ std::vector<int64> original_layout(
+ original.layout().minor_to_major().begin(),
+ original.layout().minor_to_major().end());
+ std::reverse(original_layout.begin(), original_layout.end());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(original_layout);
+ ASSERT_EQ(2, LiteralUtil::Get<float>(*literal, {0, 1}));
+ }
+ // Use the original shape in building the computation.
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.Parameter(0, original, "input");
+ // Use the slice operator to get an off-diagonal element.
+ builder.Slice(input, {0, 1}, {1, 2});
+
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ // Check that we got the off-diagonal value that we expected.
+ Array2D<float> expected(1, 1);
+ expected(0, 0) = 2;
+ ComputeAndCompareR2(&builder, expected, {data.get()}, ErrorSpec(1e-3));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
new file mode 100644
index 0000000000..96393c41e8
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Miscellaneous tests with the PRED type that don't fit anywhere else.
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class PredTest : public ClientLibraryTestBase {
+ protected:
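+ // Builds (builder.*op)(lhs, rhs) over two scalar PRED constants and checks
+ // that the computation evaluates to `expected`.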
+ void TestCompare(bool lhs, bool rhs, bool expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<bool>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<bool>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<bool>(&builder, expected, {});
+ }
+};
+
+TEST_F(PredTest, ConstantR0PredTrue) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<bool>(true);
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+TEST_F(PredTest, ConstantR0PredFalse) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR0<bool>(false);
+ ComputeAndCompareR0<bool>(&builder, false, {});
+}
+
+TEST_F(PredTest, ConstantR0PredCompareEq) {
+ TestCompare(true, false, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareNe) {
+ TestCompare(true, false, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareLe) {
+ TestCompare(true, false, false, &ComputationBuilder::Le);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareLt) {
+ TestCompare(true, false, false, &ComputationBuilder::Lt);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareGe) {
+ TestCompare(true, false, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(PredTest, ConstantR0PredCompareGt) {
+ TestCompare(true, false, true, &ComputationBuilder::Gt);
+}
+
+TEST_F(PredTest, ConstantR1Pred) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<bool>({true, false, false, true});
+ ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
+}
+
+TEST_F(PredTest, ConstantR2Pred) {
+ ComputationBuilder builder(client_, TestName());
+ auto a =
+ builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
+ const string expected = R"(pred[2,3] {
+ { 011 },
+ { 100 },
+})";
+ EXPECT_EQ(expected, ExecuteToString(&builder, {}));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
new file mode 100644
index 0000000000..8d77b3dd61
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class PrngTest : public ClientLibraryTestBase {
+ protected:
+ template <typename T>
+ void UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims);
+ void BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims);
+};
+
+template <typename T>
+void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngUniform(
+ builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
+ ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
+
+ auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
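+ // Check that every generated value lies within the requested range [a, b].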
+ LiteralUtil::EachCell<T>(*actual,
+ [=](tensorflow::gtl::ArraySlice<int64>, T value) {
+ EXPECT_LE(a, value);
+ EXPECT_GE(b, value);
+ });
+}
+
+void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice<int64> dims) {
+ ComputationBuilder builder(client_, TestName());
+ auto shape = ShapeUtil::MakeShape(U32, dims);
+ builder.RngBernoulli(builder.ConstantR0<float>(p), shape);
+
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build());
+ constexpr uint64 kTestSeed = 42;
+ TF_ASSIGN_OR_ASSERT_OK(
+ auto actual,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/kTestSeed));
+ EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
+ int32 sum = 0;
+ LiteralUtil::EachCell<uint32>(
+ *actual, [&sum](tensorflow::gtl::ArraySlice<int64>, uint32 value) {
+ EXPECT_TRUE(value == 0 || value == 1);
+ sum += value;
+ });
+ int32 total = ShapeUtil::ElementsIn(shape);
+ float p_tilde = sum / static_cast<float>(total);
+
+ // Check that the sample proportion lies within the expected range using the
+ // normal approximation. The test uses a fixed seed, so it has a fixed output
+ // per p and backend. The normal approximation is used because this test is
+ // invoked for different values of `p`, and different backends may use
+ // different random number generators and so produce different values. We
+ // choose a 95% confidence level, so that z_{1-\alpha/2} = 1.96.
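+ // For example, with p = 0.5 and 100 samples the allowed deviation is
+ // 1.96 * sqrt(0.25 / 100), which is about 0.098.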
+ float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total);
+ EXPECT_GE(p_tilde, p - normal_approximation_term);
+ EXPECT_LE(p_tilde, p + normal_approximation_term);
+}
+
+// Uniform random number generation tests
+XLA_TEST_F(PrngTest, ScalarU01) { UniformTest<float>(0, 1, {}); }
+XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest<float>(0, 1, {0}); }
+XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest<float>(0, 1, {10}); }
+XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest<float>(3, 7, {10}); }
+XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest<float>(0, 1, {0, 20}); }
+XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
+XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
+
+XLA_TEST_F(PrngTest, MapUsingRng) {
+ // Build an x -> (x + U[0,1)) computation.
+ auto build_sum_rng = [this](ComputationBuilder& builder) {
+ auto b = builder.CreateSubBuilder("sum_with_rng");
+ auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input");
+ b->Add(x,
+ b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {})));
+ return b->BuildAndNoteError();
+ };
+
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> param0_literal =
+ LiteralUtil::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
+ TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr<GlobalData> param0_data,
+ client_->TransferToServer(*param0_literal));
+
+ auto param0 = builder.Parameter(0, param0_literal->shape(), "param0");
+ auto fn = build_sum_rng(builder);
+ builder.Map({param0}, fn);
+
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build());
+ TF_ASSIGN_OR_ASSERT_OK(
+ auto actual,
+ client_->ExecuteAndTransfer(computation,
+ /*arguments=*/{param0_data.get()}, nullptr,
+ nullptr, /*seed=*/125));
+ EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size());
+ for (int i = 0; i < param0_literal->f32s_size(); ++i) {
+ EXPECT_GE(actual->f32s(i), param0_literal->f32s(i));
+ EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f);
+ }
+}
+
+// This test demonstrates the global seeding behaviour:
+// * If a seed is passed in via Execute (ExecuteAndTransfer) then the output is
+// fixed (i.e., there is a single output for a given seed);
+// * If no seed is passed in then the output of every call can be different.
+XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
+ // Build a U[0,1) computation.
+ auto build_computation = [this]() {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngUniform(builder.ConstantR0<float>(0),
+ builder.ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {10}));
+ return builder.Build();
+ };
+
+ std::unique_ptr<Literal> result1;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result1,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ }
+ std::unique_ptr<Literal> result2;
+ std::unique_ptr<Literal> result3;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result2,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result3,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/42));
+ }
+
+ std::unique_ptr<Literal> result4;
+ std::unique_ptr<Literal> result5;
+ std::unique_ptr<Literal> result6;
+ {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, build_computation());
+ TF_ASSIGN_OR_ASSERT_OK(
+ result4,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr,
+ /*seed=*/65));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result5,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr));
+ TF_ASSIGN_OR_ASSERT_OK(
+ result6,
+ client_->ExecuteAndTransfer(computation, /*arguments=*/{},
+ /*shape_with_output_layout=*/nullptr,
+ /*execution_profile=*/nullptr));
+ }
+
+ LiteralTestUtil::ExpectEqual(*result1, *result2);
+ LiteralTestUtil::ExpectEqual(*result1, *result3);
+ LiteralTestUtil::ExpectNotEqual(*result1, *result4);
+ LiteralTestUtil::ExpectNotEqual(*result4, *result5);
+ LiteralTestUtil::ExpectNotEqual(*result5, *result6);
+}
+
+// Bernoulli random number generation tests
+XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); }
+XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); }
+
+XLA_TEST_F(PrngTest, TenValuesN01) {
+ ComputationBuilder builder(client_, TestName());
+ builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {10}));
+
+ ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ // TODO(b/25995601): Test that resultant values are reasonable
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
new file mode 100644
index 0000000000..eb7e63705b
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -0,0 +1,61 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class QueryInferredShapeTest : public ClientLibraryTestBase {};
+
+TEST_F(QueryInferredShapeTest, OnePlusOneShape) {
+ ComputationBuilder builder(client_, "one_plus_one");
+ auto one = builder.ConstantR0<float>(1.0);
+ auto result = builder.Add(one, one);
+ StatusOr<std::unique_ptr<Shape>> shape_status = builder.GetShape(result);
+ ASSERT_IS_OK(shape_status.status());
+ auto shape = shape_status.ConsumeValueOrDie();
+ ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {})));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
new file mode 100644
index 0000000000..f3d8da5c8c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -0,0 +1,506 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that multi-dimensional arrays can be reduced among various
+// user-provided dimensions.
+//
+// Note that comments for these tests are white-box in that they talk about the
+// default data layout.
+//
+// The test space for reductions is the Cartesian product of:
+//
+// <possible ranks> x
+// <possible layouts for chosen rank> x
+// <possible subsets of dimensions in chosen rank>
+
+#include <stdlib.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReduceTest : public ClientLibraryTestBase {
+ protected:
+ ReduceTest() {
+ // Implementation note: laid out z >> y >> x by default.
+ // clang-format off
+ literal_2d_ = LiteralUtil::CreateR2<float>({
+ // x0 x1 x2
+ { 1.f, 2.f, 3.f}, // y0
+ { 4.f, 5.f, 6.f}, // y1
+ });
+ literal_3d_ = LiteralUtil::CreateR3Projected<float>({
+ // x0 x1 x2
+ { 1.f, 2.f, 3.f}, // y0
+ { 4.f, 5.f, 6.f}, // y1
+ }, 4);
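+ // CreateR3Projected replicates the 2x3 matrix above into 4 identical z
+ // planes, giving shape {z=4, y=2, x=3}.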
+ // clang-format on
+ CHECK(ShapeUtil::Equal(
+ literal_3d_->shape(),
+ ShapeUtil::MakeShape(F32, {/*z=*/4, /*y=*/2, /*x=*/3})))
+ << literal_3d_->shape().ShortDebugString();
+ }
+
+ // Runs an R1 => R0 reduction test with the given number of elements.
+ void RunR1ToR0Test(int64 element_count) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ std::vector<float> input_data(element_count);
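+ // Fill with small integers in {-2, -1, 0, 1, 2} so the expected sum stays
+ // exactly representable in float.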
+ for (int64 i = 0; i < element_count; ++i) {
+ input_data[i] = rand_r(&seed_) % 3;
+ if (rand_r(&seed_) % 2 == 0) {
+ input_data[i] *= -1;
+ }
+ }
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR1(AsSlice(input_data));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ float expected = 0.0;
+ for (float item : input_data) {
+ expected += item;
+ }
+ ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.001));
+ }
+
+ // Runs an R2 => R0 reduction test with the given number of (rows, cols).
+ void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0, 1});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout({minor, major}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ float expected = 0.0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ for (int64 colno = 0; colno < cols; ++colno) {
+ expected += input_data(rowno, colno);
+ }
+ }
+ ComputeAndCompareR0<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+ }
+
+ // Runs an R2 => R1 reduction test with the given number of (rows, cols).
+ void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ builder.Reduce(input, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout({minor, major}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 colno = 0; colno < cols; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += input_data(rowno, colno);
+ }
+ expected.push_back(column_sum);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+ }
+
+ std::unique_ptr<Literal> literal_2d_;
+ std::unique_ptr<Literal> literal_3d_;
+ uint32 seed_ = 0xdeadbeef;
+};
+
+XLA_TEST_F(ReduceTest, ReduceR1_0_F32_To_R0) { RunR1ToR0Test(0); }
+XLA_TEST_F(ReduceTest, ReduceR1_1_F32_To_R0) { RunR1ToR0Test(1); }
+XLA_TEST_F(ReduceTest, ReduceR1_2_F32_To_R0) { RunR1ToR0Test(2); }
+XLA_TEST_F(ReduceTest, ReduceR1_16_F32_To_R0) { RunR1ToR0Test(16); }
+XLA_TEST_F(ReduceTest, ReduceR1_240_F32_To_R0) { RunR1ToR0Test(240); }
+XLA_TEST_F(ReduceTest, ReduceR1_128_F32_To_R0) { RunR1ToR0Test(128); }
+XLA_TEST_F(ReduceTest, ReduceR1_129_F32_To_R0) { RunR1ToR0Test(129); }
+XLA_TEST_F(ReduceTest, ReduceR1_256_F32_To_R0) { RunR1ToR0Test(256); }
+XLA_TEST_F(ReduceTest, ReduceR1_1024_F32_To_R0) { RunR1ToR0Test(1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_2048_F32_To_R0) { RunR1ToR0Test(2048); }
+XLA_TEST_F(ReduceTest, ReduceR1_16K_F32_To_R0) { RunR1ToR0Test(16 * 1024); }
+XLA_TEST_F(ReduceTest, ReduceR1_16KP1_F32_To_R0) {
+ RunR1ToR0Test(16 * 1024 + 1);
+}
+
+XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R0) { RunR2ToR0Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R0) { RunR2ToR0Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R0) { RunR2ToR0Test(1, 1); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R0) { RunR2ToR0Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R0) { RunR2ToR0Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R0) { RunR2ToR0Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R0) { RunR2ToR0Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R0) { RunR2ToR0Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R0) { RunR2ToR0Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R0) {
+ RunR2ToR0Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R0) { RunR2ToR0Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R0) { RunR2ToR0Test(1000, 1500); }
+
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_0x0_To_R1) { RunR2ToR1Test(0, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_0x2_To_R1) { RunR2ToR1Test(0, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_1x1_To_R1) { RunR2ToR1Test(1, 1); }
+// Disabled due to b/33245142. Failed on 2016-11-30.
+// XLA_TEST_F(ReduceTest, ReduceR2_2x0_To_R1) { RunR2ToR1Test(2, 0); }
+XLA_TEST_F(ReduceTest, ReduceR2_2x2_To_R1) { RunR2ToR1Test(2, 2); }
+XLA_TEST_F(ReduceTest, ReduceR2_8x8_To_R1) { RunR2ToR1Test(8, 8); }
+XLA_TEST_F(ReduceTest, ReduceR2_9x9_To_R1) { RunR2ToR1Test(9, 9); }
+XLA_TEST_F(ReduceTest, ReduceR2_50x111_To_R1) { RunR2ToR1Test(50, 111); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_To_R1) { RunR2ToR1Test(111, 50); }
+XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) {
+ RunR2ToR1Test(111, 50, 0, 1);
+}
+XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); }
+XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); }
+
+XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
+ const int64 rows = 111, cols = 50;
+
+ ComputationBuilder builder(client_, TestName());
+ Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto log_ = builder.Log(input);
+ builder.Reduce(log_, zero, add_f32, /*dimensions_to_reduce=*/{0});
+
+ Array2D<float> input_data(rows, cols);
+ input_data.FillRandom(3.14f, 0.04);
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR2FromArray2D(input_data);
+ input_literal =
+ LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
+ std::unique_ptr<GlobalData> input_global_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ std::vector<float> expected;
+ for (int64 colno = 0; colno < cols; ++colno) {
+ float column_sum = 0;
+ for (int64 rowno = 0; rowno < rows; ++rowno) {
+ column_sum += log(input_data(rowno, colno));
+ }
+ expected.push_back(column_sum);
+ }
+ ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
+ ErrorSpec(0.01, 1e-4));
+}
+
+struct BoundsLayout {
+ std::vector<int64> bounds;
+ std::vector<int64> layout;
+ std::vector<int64> reduce_dims;
+};
+
+void PrintTo(const BoundsLayout& spec, std::ostream* os) {
+ *os << tensorflow::strings::Printf(
+ "R%luToR%lu%s_%s_Reduce%s", spec.bounds.size(),
+ spec.bounds.size() - spec.reduce_dims.size(),
+ tensorflow::str_util::Join(spec.bounds, "x").c_str(),
+ tensorflow::str_util::Join(spec.layout, "").c_str(),
+ tensorflow::str_util::Join(spec.reduce_dims, "").c_str());
+}
+
+// Add-reduces a broadcasted scalar matrix among dimensions 0 and 1.
+XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto add = CreateScalarAddComputation(F32, &builder);
+ auto scalar = builder.ConstantR0<float>(42.0);
+ auto broadcasted = builder.Broadcast(scalar, {500, 500});
+ builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ float expected = 42.0f * static_cast<float>(500 * 500);
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a broadcasted scalar matrix among dimensions 0 and 1.
+XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto max = CreateScalarMaxComputation(F32, &builder);
+ auto scalar = builder.ConstantR0<float>(42.0);
+ auto broadcasted = builder.Broadcast(scalar, {500, 500});
+ builder.Reduce(broadcasted, builder.ConstantR0<float>(0.0f), max, {0, 1});
+
+ float expected = 42.0f;
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Max-reduces a matrix among dimensions 0 and 1.
+XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto max = CreateScalarMaxComputation(F32, &builder);
+ Array2D<float> input(300, 250);
+ input.FillRandom(214.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ builder.Reduce(builder.ConstantLiteral(*input_literal),
+ builder.ConstantR0<float>(FLT_MIN), max, {0, 1});
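+ // Note: FLT_MIN is the smallest positive float, not the most negative value;
+ // the reference maximum below starts from the same initializer, so the two
+ // stay consistent.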
+ auto input_max = FLT_MIN;
+ input.Each(
+ [&](int64, int64, float* v) { input_max = std::max(input_max, *v); });
+ ComputeAndCompareR0<float>(&builder, input_max, {}, ErrorSpec(0.0001));
+}
+
+// Min-reduces a matrix among dimensions 0 and 1.
+XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto min = CreateScalarMinComputation(F32, &builder);
+ Array2D<float> input(150, 130);
+ input.FillRandom(214.0f);
+ auto input_literal = LiteralUtil::CreateR2FromArray2D(input);
+ builder.Reduce(builder.ConstantLiteral(*input_literal),
+ builder.ConstantR0<float>(FLT_MAX), min, {0, 1});
+
+ auto input_min = FLT_MAX;
+ input.Each(
+ [&](int64, int64, float* v) { input_min = std::min(input_min, *v); });
+ ComputeAndCompareR0<float>(&builder, input_min, {}, ErrorSpec(0.0001));
+}
+
+// Reduces a matrix among dimension 1.
+XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+ std::vector<float> expected = {6.f, 15.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
+ // Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ ComputeAndCompareR0<float>(&builder, 21.0f, {}, ErrorSpec(0.0001, 1e-4));
+}
+
+// Tests 2D matrix ReduceToRow operation.
+XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
+ ComputationBuilder builder(client_, "reduce_among_y");
+ auto m = builder.ConstantLiteral(*literal_2d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+ std::vector<float> expected = {5.f, 7.f, 9.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1, 2});
+
+ std::vector<float> expected = {21.f, 21.f, 21.f, 21.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
+
+ std::vector<float> expected = {20.f, 28.f, 36.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1, 2});
+
+ float expected = 21.0f * 4.0;
+ ComputeAndCompareR0<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
+
+ // clang-format off
+ Array2D<float> expected({
+ {4.f, 8.f, 12.f},
+ {16.f, 20.f, 24.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
+
+ // clang-format off
+ Array2D<float> expected({
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ {5.f, 7.f, 9.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
+ ComputationBuilder builder(client_, TestName());
+ auto m = builder.ConstantLiteral(*literal_3d_);
+ auto add = CreateScalarAddComputation(F32, &builder);
+ builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {2});
+
+ // clang-format off
+ Array2D<float> expected({
+ {6.f, 15.f},
+ {6.f, 15.f},
+ {6.f, 15.f},
+ {6.f, 15.f},
+ });
+ // clang-format on
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+class ReduceR3ToR2Test : public ReduceTest,
+ public ::testing::WithParamInterface<BoundsLayout> {};
+
+XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
+ ComputationBuilder builder(client_, TestName());
+ const auto& bounds = GetParam().bounds;
+ Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
+ input_array.FillRandom(3.14f, 0.05);
+
+ auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array);
+ input_literal = LiteralUtil::Relayout(
+ *input_literal, LayoutUtil::MakeLayout(GetParam().layout));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ auto input_activations =
+ builder.Parameter(0, input_literal->shape(), "input");
+ Computation add = CreateScalarAddComputation(F32, &builder);
+ auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f),
+ add, GetParam().reduce_dims);
+
+ auto expected =
+ ReferenceUtil::Reduce3DTo2D(input_array, 0.0f, GetParam().reduce_dims,
+ [](float a, float b) { return a + b; });
+
+ ComputeAndCompareR2<float>(&builder, *expected, {input_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ ReduceR3ToR2Test_Instantiation, ReduceR3ToR2Test,
+ // Specifies (shape, layout, reduction dimensions).
+ ::testing::Values(BoundsLayout{{4, 8, 128}, {2, 1, 0}, {0}},
+ BoundsLayout{{4, 8, 128}, {2, 1, 0}, {1}},
+ BoundsLayout{{4, 8, 128}, {2, 1, 0}, {2}},
+ // These should be simplified into a reshape.
+ BoundsLayout{{1, 21, 43}, {2, 1, 0}, {0}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {0}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {1}},
+ BoundsLayout{{1, 1, 1}, {2, 1, 0}, {2}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {0}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {1}},
+ BoundsLayout{{8, 16, 24}, {0, 1, 2}, {2}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {0}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {1}},
+ BoundsLayout{{5, 10, 250}, {2, 1, 0}, {2}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {0}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {1}},
+ BoundsLayout{{8, 16, 256}, {2, 1, 0}, {2}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {2}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {1}},
+ BoundsLayout{{2, 300, 784}, {2, 1, 0}, {0}}));
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
new file mode 100644
index 0000000000..f48c14dfc6
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -0,0 +1,445 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests the reduce-window XLA operation.
+
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReduceWindowTest : public ClientLibraryTestBase {
+ public:
+ ReduceWindowTest() : builder_(client_, TestName()) {}
+
+ void ReduceWindowAdd(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(input, builder_.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder_),
+ window_dimensions, window_strides, padding);
+ }
+
+ void ReduceWindowMax(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(
+ input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)),
+ CreateScalarMax(), window_dimensions, window_strides, padding);
+ }
+
+ void ReduceWindowMin(ComputationDataHandle input,
+ tensorflow::gtl::ArraySlice<int64> window_dimensions,
+ tensorflow::gtl::ArraySlice<int64> window_strides,
+ Padding padding) {
+ builder_.ReduceWindow(input,
+ builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)),
+ CreateScalarMinComputation(F32, &builder_),
+ window_dimensions, window_strides, padding);
+ }
+
+ ComputationBuilder builder_;
+};
+
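+// A reduce-window slides a window of the given dimensions over the operand
+// with the given strides and reduces each window with the given computation,
+// starting from the init value. For example, in Add4In16Stride4 below the
+// input 1..16 with window {4}, stride {4} and valid padding yields
+// (16 - 4) / 4 + 1 = 4 windows, which sum to {10, 26, 42, 58}.
+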
+XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) {
+ Array4D<float> input_array(1, 0, 2, 1);
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
+ {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, NonSquareSmall) {
+ Array4D<float> input_array(1, 2, 2, 1);
+ input_array.FillRandom(2.f);
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1},
+ {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, MiddleDimsSmall) {
+ Array4D<float> input_array(1, 3, 3, 1);
+ input_array.FillRandom(2.f);
+
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1},
+ {1, 2, 2, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, Along2ndMinorDim) {
+ Array4D<float> input_array(3, 6, 7, 32);
+ input_array.FillRandom(2.f);
+
+ // The parameters of this reduction mimic feature norm (e.g. LRN).
+ int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kSame;
+ ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) {
+ Array4D<float> input_array(9, 12, 4, 89);
+ input_array.FillRandom(2.0f);
+
+ int win_len = 3;
+ int win_stride = 2;
+
+ const auto input_data_handle =
+ builder_.ConstantR4FromArray4D<float>(input_array);
+
+ Padding padding = Padding::kSame;
+ // Reduce only along the x and y dimensions, according to the win_len.
+ ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/32173947): Test support for arbitrary-sized padding.
+TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) {
+ Array4D<float> input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout
+ input_array.FillRandom(2.0f);
+
+ int64 rank = 4;
+ int win_len = 3;
+ int win_stride = 2;
+
+ const auto input_data_handle =
+ builder_.ConstantR4FromArray4D<float>(input_array);
+
+ Padding padding = Padding::kSame;
+ // Reduce only along the x and y dimensions, according to the win_len.
+ // Create padding vector with large padding values in the reduction dims.
+ std::vector<std::pair<int64, int64>> low_high_padding;
+ low_high_padding.resize(rank, {4, 4});
+
+ builder_.ReduceWindowWithGeneralPadding(
+ input_data_handle, builder_.ConstantR0<float>(0.0f),
+ CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, low_high_padding);
+
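+  // Note: the reference result below is still computed with kSame padding
+  // rather than the explicit {4, 4} padding used above, so it will likely
+  // need to be updated when this test is enabled.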
+ auto result = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {win_len, win_len, 1, 1},
+ {win_stride, win_stride, 1, 1}, padding);
+
+ ComputeAndCompareR4<float>(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes.
+TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmall) {
+ Array4D<float> input_array(2, 2, 4, 16);
+
+ Array2D<float> yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
+ 11.f, 12.f, 13.f, 14.f, 15.f},
+ {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
+ 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f},
+ {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,
+ 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f},
+ {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,
+ 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}});
+ input_array.FillWithYX(yx);
+
+ int win_len = 2;
+ int win_stride = 2;
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ Padding padding = Padding::kValid;
+ ReduceWindowAdd(input, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, padding);
+
+ auto res = ReferenceUtil::ReduceWindow4DAdd(
+ input_array, 0.0f, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, padding);
+ ComputeAndCompareR4<float>(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+// TODO(b/31809540): Implement minor dim reduction to reduce num of reshapes.
+TEST_F(ReduceWindowTest, ReduceR4AmongXYMinorSmallOverlapped) {
+ constexpr int64 p = 2;
+ constexpr int64 z = 2;
+ constexpr int64 y = 4;
+ constexpr int64 x = 16;
+ Array4D<float> input_array(p, z, y, x);
+
+ Array2D<float> yx({{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,
+ 11.f, 12.f, 13.f, 14.f, 15.f},
+ {16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f,
+ 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f},
+ {32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f,
+ 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f},
+ {48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f,
+ 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f}});
+ input_array.FillWithYX(yx);
+
+ int win_len = 4;
+ int win_stride = 2;
+ const auto input = builder_.ConstantR4FromArray4D<float>(input_array);
+ ReduceWindowAdd(input, {1, 1, win_len, win_len},
+ {1, 1, win_stride, win_stride}, Padding::kValid);
+
+ // Expected result
+ Array2D<float> yx_result({{408.f, 440.f, 472.f, 504.f, 536.f, 568.f, 600.f}});
+ Array4D<float> expected(p, z, 1, 7);
+ expected.FillWithYX(yx_result);
+ ComputeAndCompareR4<float>(&builder_, expected, {}, ErrorSpec(1e-3, 1e-3));
+}
+
+TEST_F(ReduceWindowTest, MaxTrivial) {
+ const auto input = builder_.ConstantR1<float>({42});
+ ReduceWindowMax(input, {1}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {42}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In3) {
+ const auto input = builder_.ConstantR1<float>({20, 100, 3});
+ ReduceWindowAdd(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {123}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add4In16Stride4) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowAdd(input, {4}, {4}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {10, 26, 42, 58}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, DISABLED_ON_CPU(DISABLED_ON_GPU(Min3In5Stride2))) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowMin(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {100, 1}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In3) {
+ const auto input = builder_.ConstantR1<float>({20, 100, 3});
+ ReduceWindowMax(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {100}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2In3) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {2}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {110, 11}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In5Stride2) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {11100, 111}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride4) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {4}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 8, 12, 16}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride3) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {3}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 7, 10, 13, 16}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max4In16Stride8) {
+ const auto input = builder_.ConstantR1<float>(
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ ReduceWindowMax(input, {4}, {8}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {4, 12}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In5Stride2) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 1});
+ ReduceWindowMax(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {10000, 100}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Max3In5Stride1) {
+ const auto input = builder_.ConstantR1<float>({10000, 1000, 100, 10, 101});
+ ReduceWindowMax(input, {3}, {1}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {10000, 1000, 101}, {},
+ ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add3In4Stride2) {
+ const auto input = builder_.ConstantR1<float>({1000, 100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kValid);
+ ComputeAndCompareR1<float>(&builder_, {1110}, {}, ErrorSpec(0.0001));
+}
+
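+// With kSame padding the output keeps ceil(input_size / stride) windows, and
+// windows that run past the edge are padded with the init value. For example,
+// in Add2In3SamePad below, {100, 10, 1} with window {2} and stride {1}
+// produces {100 + 10, 10 + 1, 1 + 0} = {110, 11, 1}.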
+XLA_TEST_F(ReduceWindowTest, Add2In3SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {2}, {1}, Padding::kSame);
+ ComputeAndCompareR1<float>(&builder_, {110, 11, 1}, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add3In3SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {3}, {1}, Padding::kSame);
+ ComputeAndCompareR1<float>(&builder_, {110, 111, 11}, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add3In3Stride2SamePad) {
+ const auto input = builder_.ConstantR1<float>({100, 10, 1});
+ ReduceWindowAdd(input, {3}, {2}, Padding::kSame);
+ ComputeAndCompareR1<float>(&builder_, {110, 11}, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2x2In2x2Overlapped) {
+ Array2D<float> input_array({{1.2f, -2.5f, 0.9f, 1.0f},
+ {3.7f, 0.2f, -1.0f, -0.2f},
+ {-0.4f, 2.7f, 1.1f, 2.2f},
+ {0.6f, 1.7f, 1.4f, -0.2f}});
+ auto input = builder_.ConstantR2FromArray2D<float>(input_array);
+ ReduceWindowAdd(input, {2, 2}, {1, 1}, Padding::kValid);
+ Array2D<float> expected(
+ {{2.6f, -2.4f, 0.7f}, {6.2f, 3.0f, 2.1f}, {4.6f, 6.9f, 4.5f}});
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(ReduceWindowTest, Add2x2In2x2Disjoint) {
+ Array2D<float> input_array({{1.2f, -2.5f, 0.9f, 1.0f},
+ {3.7f, 0.2f, -1.0f, -0.2f},
+ {-0.4f, 2.7f, 1.1f, 2.2f},
+ {0.6f, 1.7f, 1.4f, -0.2f}});
+ auto input = builder_.ConstantR2FromArray2D<float>(input_array);
+ ReduceWindowAdd(input, {2, 2}, {2, 2}, Padding::kValid);
+ Array2D<float> expected({
+ {2.6f, 0.7f}, {4.6f, 4.5f},
+ });
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) {
+ Array3D<float> input_array(2, 1, 2);
+ input_array(0, 0, 0) = 1000;
+ input_array(0, 0, 1) = 100;
+ input_array(1, 0, 0) = 10;
+ input_array(1, 0, 1) = 1;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid);
+
+ Array3D<float> expected(2, 1, 1);
+ expected(0, 0, 0) = 1100;
+ expected(1, 0, 0) = 11;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) {
+ Array3D<float> input_array(2, 1, 3);
+ input_array(0, 0, 0) = 100;
+ input_array(0, 0, 1) = 10;
+ input_array(0, 0, 2) = 1;
+ input_array(1, 0, 0) = 500;
+ input_array(1, 0, 1) = 50;
+ input_array(1, 0, 2) = 5;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid);
+
+ Array3D<float> expected(2, 1, 1);
+ expected(0, 0, 0) = 110;
+ expected(1, 0, 0) = 550;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) {
+ Array3D<float> input_array(2, 1, 3);
+ input_array(0, 0, 0) = 100;
+ input_array(0, 0, 1) = 10;
+ input_array(0, 0, 2) = 1;
+ input_array(1, 0, 0) = 500;
+ input_array(1, 0, 1) = 50;
+ input_array(1, 0, 2) = 5;
+ auto input = builder_.ConstantR3FromArray3D<float>(input_array);
+
+ ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame);
+
+ Array3D<float> expected(2, 1, 3);
+ expected(0, 0, 0) = 110;
+ expected(0, 0, 1) = 11;
+ expected(0, 0, 2) = 1;
+ expected(1, 0, 0) = 550;
+ expected(1, 0, 1) = 55;
+ expected(1, 0, 2) = 5;
+ ComputeAndCompareR3<float>(&builder_, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
new file mode 100644
index 0000000000..802087b508
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -0,0 +1,168 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/protobuf_util.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
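+// Tests that a Computation can be serialized to a SessionModule proto with
+// Computation::Snapshot(), loaded back via client_->LoadSnapshot(), and then
+// executed with the same program shape and results as the original.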
+class ReplayTest : public ClientLibraryTestBase {};
+
+TEST_F(ReplayTest, TwoPlusTwoReplay) {
+ // Make 2+2 computation.
+ ComputationBuilder builder(client_, TestName());
+ auto two = builder.ConstantR0<int32>(2);
+ builder.Add(two, two);
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Run it.
+ std::unique_ptr<Literal> literal =
+ client_->ExecuteAndTransfer(replayed, /*arguments=*/{})
+ .ConsumeValueOrDie();
+
+ // Expect 4.
+ LiteralTestUtil::ExpectR0Equal<int32>(4, *literal);
+}
+
+XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
+ // Make computation.
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
+ builder.Add(x, y);
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Run it.
+ std::unique_ptr<GlobalData> x_data =
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(2))
+ .ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> y_data =
+ client_->TransferToServer(*LiteralUtil::CreateR0<int32>(3))
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> literal =
+ client_
+ ->ExecuteAndTransfer(replayed,
+ /*arguments=*/{x_data.get(), y_data.get()})
+ .ConsumeValueOrDie();
+
+ // Expect 5.
+ LiteralTestUtil::ExpectR0Equal<int32>(5, *literal);
+}
+
+TEST_F(ReplayTest, MapPlusTwoOverR1) {
+ // As above, but with map(+2) over some constant array.
+ ComputationBuilder plus_two_builder(client_, "plus two");
+ auto input =
+ plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
+ plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
+ Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
+
+ ComputationBuilder mapper_builder(client_, TestName());
+ auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
+ mapper_builder.Map({original}, plus_two);
+
+ Computation computation = mapper_builder.Build().ConsumeValueOrDie();
+
+ // Serialize it out.
+ std::unique_ptr<SessionModule> module =
+ computation.Snapshot().ConsumeValueOrDie();
+
+ // Replay it.
+ Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+
+ // Check signature is the same.
+ std::unique_ptr<ProgramShape> original_shape =
+ client_->GetComputationShape(computation).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> replayed_shape =
+ client_->GetComputationShape(replayed).ConsumeValueOrDie();
+ ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+
+ // Destroy the originals.
+ computation.Reset();
+ plus_two.Reset();
+
+ // Run it.
+ std::unique_ptr<Literal> literal =
+ client_->ExecuteAndTransfer(replayed, /*arguments=*/{})
+ .ConsumeValueOrDie();
+
+ // Expect result.
+ LiteralTestUtil::ExpectR1Equal<int32>({3, 4, 5}, *literal);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
new file mode 100644
index 0000000000..ce309eb743
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -0,0 +1,77 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+using ReshapeMotionTest = ClientLibraryTestBase;
+
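+// Applies an elementwise operation to reshapes of two differently shaped
+// inputs; this is the kind of pattern that reshape motion (moving reshapes
+// across elementwise operations) is presumably intended to handle.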
+TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<int32>({{2, 3, 5}, {7, 11, 13}});
+ auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}});
+ auto c = builder.Reshape(a, {6});
+ auto d = builder.Reshape(b, {6});
+ auto e = builder.Mul(c, d);
+
+ ComputeAndCompareR1<int32>(&builder, {34, 57, 115, 203, 341, 481}, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc
new file mode 100644
index 0000000000..a9159d39ca
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reshape_test.cc
@@ -0,0 +1,811 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <random>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReshapeTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec zero_error_spec_{0.0};
+};
+
+// Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension.
+XLA_TEST_F(ReshapeTest, Trivial1x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0}});
+ builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f}, {}, zero_error_spec_);
+}
+
+// Collapses 2-dimensional pseudo-scalar (single-element array) to scalar.
+XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0}});
+ auto reshape =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{});
+ auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<float>(&builder, 1.0f, {}, zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, Trivial0x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 3));
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, Trivial3x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, zero_error_spec_);
+}
+
+// Collapses a 2-dimensional row vector to 1 dimension.
+XLA_TEST_F(ReshapeTest, Trivial1x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0f, 2.0f, 3.0f}});
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f, 2.0f, 3.0f}, {},
+ zero_error_spec_);
+}
+
+// Collapses a 2-dimensional column vector to 1 dimension.
+XLA_TEST_F(ReshapeTest, Trivial3x1) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2<float>({{1.0f}, {2.0f}, {3.0f}});
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder, {1.0f, 2.0f, 3.0f}, {},
+ zero_error_spec_);
+}
+
+// Splits an empty vector into an empty matrix.
+XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({});
+ auto result =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0});
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {},
+ zero_error_spec_);
+}
+
+// Splits a vector into a matrix.
+XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f});
+ auto result =
+ builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3});
+ Array2D<float> expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}});
+ ComputeAndCompareR2<float>(&builder, expected_2x3, {}, zero_error_spec_);
+}
+
+// Reshapes a 0x2 array to a 2x0 array.
+XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 2));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {},
+ zero_error_spec_);
+}
+
+// Transposes a 2-dimensional row vector to a column vector.
+XLA_TEST_F(ReshapeTest, ReshapeRowToCol) {
+ ComputationBuilder builder(client_, TestName());
+ auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*simple);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{3, 1});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*simple);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, zero_error_spec_);
+}
+
+// Transposes a 2-dimensional array.
+XLA_TEST_F(ReshapeTest, TransposeAsReshape) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 4});
+
+ auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3);
+ ComputeAndCompareR2<float>(&builder, *expected3x4, {}, zero_error_spec_);
+}
+
+// Transposes a 0x4 array with ComputationBuilder::Transpose.
+XLA_TEST_F(ReshapeTest, Transpose0x4) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 4));
+ auto result = builder.Transpose(a, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(4, 0), {},
+ zero_error_spec_);
+}
+
+// Transposes a 2-dimensional array with ComputationBuilder::Transpose.
+XLA_TEST_F(ReshapeTest, Transpose4x3) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Transpose(a, {1, 0});
+
+ auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3);
+ ComputeAndCompareR2<float>(&builder, *expected3x4, {}, zero_error_spec_);
+}
+
+// Reshapes an empty 2-dimensional array with dimensions that are not just a
+// rearrangement of the originals (split), but no reordering (no shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(6, 0));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 3, 0, 0});
+
+ ComputeAndCompareR4<float>(&builder, Array4D<float>(2, 3, 0, 0), {},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR4FromArray4D<float>(Array4D<float>(2, 3, 4, 0));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{24, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(24, 0), {},
+ zero_error_spec_);
+}
+
+// Reshapes a 2-dimensional array with dimensions that are not just a
+// rearrangement of the originals (split), but no reordering (no shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1},
+ /*new_sizes=*/{2, 6});
+
+ auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6);
+ ComputeAndCompareR2<float>(&builder, *expected2x6, {}, zero_error_spec_);
+}
+
+// Reshapes a 2-dimensional array with dimensions that are not just a
+// rearrangement of the originals (split), and reorders the input (shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 6));
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{3, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(3, 0), {},
+ zero_error_spec_);
+}
+
+// Reshapes a 2-dimensional array with dimensions that are not just a
+// rearrangement of the originals (split), and reorders the input (shuffle).
+XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) {
+ ComputationBuilder builder(client_, TestName());
+ auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3);
+ auto a = builder.ConstantR2FromArray2D<float>(*a4x3);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0},
+ /*new_sizes=*/{2, 6});
+
+ Array2D<float> expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f},
+ {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}});
+ ComputeAndCompareR2<float>(&builder, expected2x6, {}, zero_error_spec_);
+}
+
+// The following tests use the same input 3D array; they test the examples we
+// show for the Reshape operation in the operation_semantics document.
+// TODO(eliben): find a way to show this code in the documentation without
+// duplication.
+Array3D<int> v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}},
+ {{20, 21, 22}, {25, 26, 27}},
+ {{30, 31, 32}, {35, 36, 37}},
+ {{40, 41, 42}, {45, 46, 47}}});
+
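+// A minimal sketch (not called by the tests) of the reshape semantics these
+// "Doc" tests exercise: Reshape(operand, dimensions, new_sizes) conceptually
+// walks the operand in the order given by `dimensions`, with the last listed
+// dimension varying fastest, and then refits the resulting flat sequence into
+// `new_sizes` in row-major order. The helper name, the explicit `bounds`
+// parameter and the std::vector representation are choices made for this
+// sketch only.
+std::vector<int> FlattenR3InDimensionOrder(
+    const Array3D<int>& array, tensorflow::gtl::ArraySlice<int64> bounds,
+    tensorflow::gtl::ArraySlice<int64> dimensions) {
+  std::vector<int> flat;
+  for (int64 i = 0; i < bounds[dimensions[0]]; ++i) {
+    for (int64 j = 0; j < bounds[dimensions[1]]; ++j) {
+      for (int64 k = 0; k < bounds[dimensions[2]]; ++k) {
+        // Scatter the loop counters back into (dim0, dim1, dim2) positions.
+        int64 index[3];
+        index[dimensions[0]] = i;
+        index[dimensions[1]] = j;
+        index[dimensions[2]] = k;
+        flat.push_back(array(index[0], index[1], index[2]));
+      }
+    }
+  }
+  return flat;
+}
+
+// For the 4x2x3 array above, dimensions {0, 1, 2} yields
+// {10, 11, 12, 15, 16, 17, 20, ...} while dimensions {1, 2, 0} yields
+// {10, 20, 30, 40, 11, 21, 31, 41, 12, ...}, matching the expectations in
+// DocR3_R1_Collapse_012 and DocR3_R1_Collapse_120 below.
+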
+XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{24});
+ ComputeAndCompareR1<int>(&builder,
+ {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
+ 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47},
+ {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2},
+ /*new_sizes=*/{8, 3});
+ Array2D<int> expected({{10, 11, 12},
+ {15, 16, 17},
+ {20, 21, 22},
+ {25, 26, 27},
+ {30, 31, 32},
+ {35, 36, 37},
+ {40, 41, 42},
+ {45, 46, 47}});
+ ComputeAndCompareR2<int>(&builder, expected, {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{24});
+ ComputeAndCompareR1<int>(&builder,
+ {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
+ 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47},
+ {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{8, 3});
+ Array2D<int> expected({{10, 20, 30},
+ {40, 11, 21},
+ {31, 41, 12},
+ {22, 32, 42},
+ {15, 25, 35},
+ {45, 16, 26},
+ {36, 46, 17},
+ {27, 37, 47}});
+ ComputeAndCompareR2<int>(&builder, expected, {});
+}
+
+XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) {
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR3FromArray3D<int>(v_array_for_doc_R3_tests);
+ auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0},
+ /*new_sizes=*/{2, 6, 2});
+ Array3D<int> expected(
+ {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}},
+ {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}});
+ ComputeAndCompareR3<int>(&builder, expected, {});
+}
+
+// Collapses the low dimensions of a 4D tensor to get a 2D matrix, without
+// reordering dimensions (for NeuralNet::FullyConnected).
+//
+// First we create a tesseract raster-face like:
+//
+// 1 2 3
+// 4 5 6
+//
+// Then we collapse Y and X within the raster space, yielding:
+//
+// 1 2 3 4 5 6
+//
+// Finally we collapse Z so we just end up with planes:
+//
+// 1 2 3 4 5 6 1 2 3 4 5 6
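+//
+// Collapsing dimensions {1, 2, 3} of the 2x2x2x3 input is equivalent to
+// reshaping it to {2, 12}; FullyConnectedCollapseDesugared below exercises
+// the same idea with Reshape directly.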
+XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> t2x2x2x3(2, 2, 2, 3);
+ auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3);
+ t2x2x2x3.FillWithYX(*filler2x3);
+ auto a = builder.ConstantR4FromArray4D<float>(t2x2x2x3);
+ auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3});
+
+ Array2D<float> expected2x12(
+ {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f},
+ {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
+ 6.0f}});
+ ComputeAndCompareR2<float>(&builder, expected2x12, {}, zero_error_spec_);
+}
+
+// As above, but uses reshape directly.
+XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> t(2, 1, 2, 2);
+ t(0, 0, 0, 0) = 0;
+ t(0, 0, 0, 1) = 1;
+ t(0, 0, 1, 0) = 2;
+ t(0, 0, 1, 1) = 3;
+ t(1, 0, 0, 0) = 4;
+ t(1, 0, 0, 1) = 5;
+ t(1, 0, 1, 0) = 6;
+ t(1, 0, 1, 1) = 7;
+ auto a = builder.ConstantR4FromArray4D<float>(t);
+ auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{2, 4});
+
+ Array2D<float> expected({{0, 1, 2, 3}, {4, 5, 6, 7}});
+ ComputeAndCompareR2<float>(&builder, expected, {}, zero_error_spec_);
+}
+
+// Reshapes arrays of various ranks to a scalar.
+XLA_TEST_F(ReshapeTest, ToScalar) {
+ for (int rank = 0; rank < 8; ++rank) {
+ ComputationBuilder b(client_, TestName());
+ auto input = LiteralUtil::CreateR1<float>({83.0f});
+ std::vector<int64> ones(rank, 1); // this is {1, ..., 1}.
+ std::vector<int64> dimensions(rank);
+ std::iota(dimensions.begin(), dimensions.end(), 0);
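+    // Reinterpret the single-element literal as a rank-`rank` array with
+    // bounds {1, ..., 1} before reshaping it down to a scalar.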
+ *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones);
+ b.Reshape(b.ConstantLiteral(*input), dimensions, {});
+
+ ComputeAndCompareR0<float>(&b, 83.0f, {}, zero_error_spec_);
+ }
+}
+
+XLA_TEST_F(ReshapeTest, BadDimensions) {
+ ComputationBuilder b(client_, TestName());
+ b.Reshape(b.ConstantR1<int32>({1}), {}, {});
+ EXPECT_MATCH(ExecuteToString(&b, {}),
+ testing::HasSubstr("dimensions not a permutation"));
+}
+
+XLA_TEST_F(ReshapeTest, BadNewSizes) {
+ ComputationBuilder b(client_, TestName());
+ b.Reshape(b.ConstantR1<int32>({1, 2}), {1}, {});
+ EXPECT_MATCH(ExecuteToString(&b, {}),
+ testing::HasSubstr("mismatched element counts"));
+}
+
+XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
+ const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2});
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, parameter_shape, "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8});
+
+ // clang-format off
+ auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D<float>{
+ {
+ {
+ {0, 1},
+ {2, 3},
+ },
+ {
+ {100, 101},
+ {102, 103},
+ },
+ },
+ {
+ {
+ {222, 333},
+ {444, 555},
+ },
+ {
+ {666, 777},
+ {888, 999},
+ },
+ },
+ },
+ LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ // clang-format on
+ std::unique_ptr<GlobalData> input =
+ client_->TransferToServer(*literal).ConsumeValueOrDie();
+ Array2D<float> expected_array({
+ {0, 1, 2, 3, 100, 101, 102, 103},
+ {222, 333, 444, 555, 666, 777, 888, 999},
+ });
+
+ Computation computation = builder.Build().ConsumeValueOrDie();
+ const Shape shape_with_output_layout =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0});
+ std::unique_ptr<Literal> actual =
+ client_
+ ->ExecuteAndTransfer(computation, {input.get()},
+ &shape_with_output_layout)
+ .ConsumeValueOrDie();
+ std::unique_ptr<Literal> expected =
+ LiteralUtil::CreateR2FromArray2D<float>(expected_array);
+ LiteralTestUtil::ExpectEqual(*expected, *actual);
+}
+
+XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) {
+ std::unique_ptr<Literal> input = LiteralUtil::CreateR2<float>({
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {100, 101, 102, 103, 104, 105, 106, 107},
+ {200, 201, 202, 203, 204, 205, 206, 207},
+ });
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4});
+
+ // clang-format off
+ Array4D<float> expected = {
+ {{{0, 1, 2, 3}},
+ {{4, 5, 6, 7}}},
+ {{{100, 101, 102, 103}},
+ {{104, 105, 106, 107}}},
+ {{{200, 201, 202, 203}},
+ {{204, 205, 206, 207}}}
+ };
+ // clang-format on
+ ComputeAndCompareR4<float>(&builder, expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+// Tests R2->R4 reshape with the reshape dimensions {1, 0}.
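+// With dimensions {1, 0} the operand is walked with dimension 0 varying
+// fastest (column by column), i.e. 0, 100, 200, 1, 101, 201, ..., and the
+// flattened sequence is then refit into the {3, 2, 1, 4} output shape.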
+XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) {
+ std::unique_ptr<Literal> input = LiteralUtil::CreateR2<float>({
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {100, 101, 102, 103, 104, 105, 106, 107},
+ {200, 201, 202, 203, 204, 205, 206, 207},
+ });
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4});
+
+ // clang-format off
+ Array4D<float> expected = {
+ {{{0, 100, 200, 1}},
+ {{101, 201, 2, 102}}},
+ {{{202, 3, 103, 203}},
+ {{4, 104, 204, 5}}},
+ {{{105, 205, 6, 106}},
+ {{206, 7, 107, 207}}}
+ };
+ // clang-format on
+ ComputeAndCompareR4<float>(&builder, expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(2, 1, 1, 1);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1});
+
+ std::unique_ptr<Literal> expected =
+ LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(2, 1, 4, 1);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2});
+
+ std::unique_ptr<Literal> expected =
+ LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_);
+}
+
+// Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}.
+XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input(5, 10, 2, 3);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60});
+
+ Array2D<float> expected_array(5, 60);
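+  // With dimensions {0, 2, 1, 3} the operand is walked as (d0, d2, d1, d3),
+  // so within each output row of 60 elements the value at (i1, i2, i3) lands
+  // at flat offset i2 * 30 + i1 * 3 + i3.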
+ input.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* cell) {
+ expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) =
+ *cell;
+ });
+ auto expected = LiteralUtil::CreateR2FromArray2D(expected_array);
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()});
+}
+
+XLA_TEST_F(ReshapeTest, NoopReshape) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ Array4D<float> input_array(2, 3, 5, 7);
+ input_array.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input_array, LayoutUtil::MakeLayout({1, 2, 3, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.Parameter(0, input_literal->shape(), "input");
+ builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2},
+ /*new_sizes=*/{7, 2, 3, 5});
+ Computation computation = builder.Build().ConsumeValueOrDie();
+
+ const Shape output_shape_with_layout =
+ ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1});
+ std::unique_ptr<Literal> output_literal =
+ client_
+ ->ExecuteAndTransfer(computation, {input_data.get()},
+ &output_shape_with_layout)
+ .ConsumeValueOrDie();
+
+ // Since the reshape is a no-op, verify that it does not change the underlying
+ // data.
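+  // (The operand layout {1, 2, 3, 0} on shape {2, 3, 5, 7} and the requested
+  // output layout {2, 3, 0, 1} on shape {7, 2, 3, 5} order the elements
+  // identically in memory, so the two buffers should be bit-identical.)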
+ EXPECT_EQ(tensorflow::gtl::ArraySlice<float>(input_literal->f32s()),
+ tensorflow::gtl::ArraySlice<float>(output_literal->f32s()));
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) {
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantLiteral(*literal_1x2x3x4);
+ builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3},
+ /*new_sizes=*/{1, 2, 3, 4});
+
+ ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {});
+}
+
+XLA_TEST_F(ReshapeTest, R4ToR4Reshape) {
+ auto literal_1x2x3x4 = LiteralUtil::CreateR4(
+ {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}},
+ {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}});
+
+ ComputationBuilder builder(client_, TestName());
+ auto input = builder.ConstantLiteral(*literal_1x2x3x4);
+ builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0},
+ /*new_sizes=*/{2, 4, 3, 1});
+
+ // clang-format off
+ auto expected_2x4x3x1 = LiteralUtil::CreateR4(
+ {{{{1}, {5}, {9}},
+ {{2}, {6}, {10}},
+ {{3}, {7}, {11}},
+ {{4}, {8}, {12}}},
+ {{{13}, {17}, {21}},
+ {{14}, {18}, {22}},
+ {{15}, {19}, {23}},
+ {{16}, {20}, {24}}}});
+ // clang-format on
+
+ ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {});
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {2, 2, 2, 2};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {1, 1, 250, 300};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+ // Specify the requested output shape explicitly to ensure that this reshape
+ // actually corresponds to a two minor transpose.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {5, 5, 1, 10};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+  // Specify the requested output shape explicitly to ensure that this reshape
+  // actually corresponds to a transpose of the two minor dimensions.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+  // This reshape pattern occurs in the NN-Builder MNIST model.
+ std::vector<int64> bounds = {5, 5, 10, 1};
+ std::vector<int64> new_bounds = {bounds[0], bounds[1], bounds[3], bounds[2]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({3, 2, 1, 0}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal),
+ LayoutUtil::MakeLayout({3, 2, 1, 0}));
+
+  // Specify the requested output shape explicitly to ensure that this reshape
+  // actually corresponds to a transpose of the two minor dimensions.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) {
+ std::mt19937 rng;
+ std::uniform_real_distribution<float> distribution;
+ std::vector<int64> bounds = {3, 3, 1, 3};
+ std::vector<int64> new_bounds = {bounds[1], bounds[0], bounds[2], bounds[3]};
+ Array4D<float> input(bounds[0], bounds[1], bounds[2], bounds[3]);
+ input.Each(
+ [&rng, &distribution](tensorflow::gtl::ArraySlice<int64> /* indices */,
+ float* cell) { *cell = distribution(rng); });
+ std::unique_ptr<Literal> input_literal =
+ LiteralUtil::CreateR4FromArray4DWithLayout(
+ input, LayoutUtil::MakeLayout({0, 1, 2, 3}));
+ std::unique_ptr<GlobalData> input_data =
+ client_->TransferToServer(*input_literal).ConsumeValueOrDie();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.Parameter(0, input_literal->shape(), "a");
+ builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds);
+
+ std::unique_ptr<Literal> expected = LiteralUtil::Relayout(
+ *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal),
+ input_literal->shape().layout());
+
+  // Specify the requested output shape explicitly to ensure that this reshape
+  // actually corresponds to a transpose of the two minor dimensions.
+ ComputeAndCompareLiteral(&builder, *expected, {input_data.get()},
+ zero_error_spec_, &expected->shape());
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
new file mode 100644
index 0000000000..63dd4421fa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -0,0 +1,173 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ReverseTest : public ClientLibraryTestBase {};
+
+// Tests the reverse operation on a scalar.
+XLA_TEST_F(ReverseTest, ReverseScalar) {
+ ComputationBuilder b(client_, TestName());
+ float input = 3.5f;
+ b.Rev(b.ConstantR0<float>(input), {});
+ ComputeAndCompareR0<float>(&b, input, {});
+}
+
+// Tests the reverse operation on a 0x0 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse0x0FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(0, 0)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(0, 0), {});
+}
+
+// Tests the reverse operation on a 0x1 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse0x1FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(0, 1)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(0, 1), {});
+}
+
+// Tests the reverse operation on a 1x0 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse1x0FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR2FromArray2D<float>(Array2D<float>(1, 0)), {0, 1});
+ ComputeAndCompareR2<float>(&b, Array2D<float>(1, 0), {});
+}
+
+// Tests the reverse operation on a 1x1 float array on both dimensions.
+XLA_TEST_F(ReverseTest, Reverse1x1FloatArray) {
+ ComputationBuilder b(client_, TestName());
+ Array2D<float> input({{3.5f}});
+ b.Rev(b.ConstantR2FromArray2D<float>(input), {0, 1});
+ ComputeAndCompareR2<float>(&b, input, {});
+}
+
+XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim02) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 4, 3)), {0, 2});
+ ComputeAndCompareR4<float>(&b, Array4D<float>(2, 0, 4, 3), {});
+}
+
+XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim13) {
+ ComputationBuilder b(client_, TestName());
+ b.Rev(b.ConstantR4FromArray4D<float>(Array4D<float>(2, 0, 4, 3)), {1, 3});
+ ComputeAndCompareR4<float>(&b, Array4D<float>(2, 0, 4, 3), {});
+}
+
+// Tests the reverse operation on a 4D U8 array on dimensions 0 and 3.
+XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
+ ComputationBuilder b(client_, TestName());
+ // Input shape is U8[1x2x3x4].
+ // clang-format off
+ Array4D<uint8> input({{
+ {{1, 2, 3, 4},
+ {5, 6, 7, 8},
+ {9, 10, 11, 12}},
+ {{13, 14, 15, 16},
+ {17, 18, 19, 20},
+ {21, 22, 23, 24}},
+ }});
+ // clang-format on
+
+ b.Rev(b.ConstantR4FromArray4D<uint8>(input), {0, 3});
+
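+  // Dimension 0 has extent 1, so reversing it is a no-op; reversing
+  // dimension 3 flips each innermost row, e.g. {1, 2, 3, 4} -> {4, 3, 2, 1}.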
+ // clang-format off
+ Array4D<uint8> expected({{
+ {{4, 3, 2, 1},
+ {8, 7, 6, 5},
+ {12, 11, 10, 9}},
+ {{16, 15, 14, 13},
+ {20, 19, 18, 17},
+ {24, 23, 22, 21}},
+ }});
+ // clang-format on
+ ComputeAndCompareR4<uint8>(&b, expected, {});
+}
+
+// Tests the reverse operation on a 4D float array on dimensions 0 and 1.
+TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
+ ComputationBuilder b(client_, TestName());
+ // Input shape is float[4x3x2x1].
+ // clang-format off
+ Array4D<float> input({
+ {{{1.0f}, {2.0f}},
+ {{3.0f}, {4.0f}},
+ {{5.0f}, {6.0f}}},
+ {{{7.0f}, {8.0f}},
+ {{9.0f}, {10.0f}},
+ {{11.0f}, {12.0f}}},
+ {{{13.0f}, {14.0f}},
+ {{15.0f}, {16.0f}},
+ {{17.0f}, {18.0f}}},
+ {{{19.0f}, {20.0f}},
+ {{21.0f}, {22.0f}},
+ {{23.0f}, {24.0f}}},
+ });
+ // clang-format on
+
+ b.Rev(b.ConstantR4FromArray4D<float>(input), {0, 1});
+
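+  // Reversing dimensions 0 and 1 reverses the order of the four outer blocks
+  // and of the three rows within each block, leaving the minor dimensions
+  // untouched.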
+ // clang-format off
+ Array4D<float> expected({
+ {{{23.0f}, {24.0f}},
+ {{21.0f}, {22.0f}},
+ {{19.0f}, {20.0f}}},
+ {{{17.0f}, {18.0f}},
+ {{15.0f}, {16.0f}},
+ {{13.0f}, {14.0f}}},
+ {{{11.0f}, {12.0f}},
+ {{9.0f}, {10.0f}},
+ {{7.0f}, {8.0f}}},
+ {{{5.0f}, {6.0f}},
+ {{3.0f}, {4.0f}},
+ {{1.0f}, {2.0f}}},
+ });
+ // clang-format on
+ ComputeAndCompareR4<float>(&b, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
new file mode 100644
index 0000000000..5b734c0f40
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/packed_literal_reader.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class RoundTripPackedLiteralTest : public ClientLibraryTestBase {
+ protected:
+ // Sends the literal to the server and retrieves it back.
+ std::unique_ptr<Literal> RoundTripToServer(const Literal& original) {
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(original).ConsumeValueOrDie();
+ return client_->Transfer(*data).ConsumeValueOrDie();
+ }
+};
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) {
+ string data(sizeof(float) * 2, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 2);
+ floats[0] = 42.0;
+ floats[1] = 24.0;
+
+ string fname = tensorflow::testing::TmpDir() + "/RoundTripsR1F32Length2.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0, LiteralUtil::Get<float>(*actual, {0}));
+ EXPECT_EQ(24.0, LiteralUtil::Get<float>(*actual, {1}));
+}
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) {
+ string data(sizeof(float) * 4, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 4);
+ // With x as the minor dimension, these will become:
+ floats[0] = 42.0; // y=0,x=0
+ floats[1] = 24.0; // y=0,x=1
+ floats[2] = 64.0; // y=1,x=0
+ floats[3] = 46.0; // y=1,x=1
+
+ string fname =
+ tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim0Minor.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
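+  // The layout lists dimensions minor-to-major; with {1, 0} the x index
+  // varies fastest in the packed data, matching the ordering above.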
+ const Layout layout = LayoutUtil::MakeLayout({1, 0});
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(*actual, {0, 0}));
+ EXPECT_EQ(24.0f, LiteralUtil::Get<float>(*actual, {0, 1}));
+ EXPECT_EQ(64.0f, LiteralUtil::Get<float>(*actual, {1, 0}));
+ EXPECT_EQ(46.0f, LiteralUtil::Get<float>(*actual, {1, 1}));
+
+ std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
+ LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+}
+
+TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) {
+ string data(sizeof(float) * 4, 0);
+ tensorflow::gtl::MutableArraySlice<float> floats(
+ tensorflow::bit_cast<float*>(data.data()), 4);
+ // With y as the minor dimension, these will become:
+ floats[0] = 42.0; // y=0,x=0
+ floats[1] = 24.0; // y=1,x=0
+ floats[2] = 64.0; // y=0,x=1
+ floats[3] = 46.0; // y=1,x=1
+
+ string fname =
+ tensorflow::testing::TmpDir() + "/RoundTripsR2F32Size2x2Dim1Minor.data";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, data)
+ .ok());
+
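+  // With minor-to-major layout {0, 1} the y index varies fastest in the
+  // packed data, matching the ordering above.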
+ const Layout layout = LayoutUtil::MakeLayout({0, 1});
+
+ std::unique_ptr<tensorflow::RandomAccessFile> f;
+ TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(fname, &f));
+ PackedLiteralReader reader(f.release());
+ std::unique_ptr<Literal> actual =
+ reader.Read(ShapeUtil::MakeShape(F32, {2, 2}), &layout)
+ .ConsumeValueOrDie();
+ EXPECT_TRUE(reader.IsExhausted());
+
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(*actual, {0, 0}));
+ EXPECT_EQ(24.0f, LiteralUtil::Get<float>(*actual, {1, 0}));
+ EXPECT_EQ(64.0f, LiteralUtil::Get<float>(*actual, {0, 1}));
+ EXPECT_EQ(46.0f, LiteralUtil::Get<float>(*actual, {1, 1}));
+
+ std::unique_ptr<Literal> round_tripped = RoundTripToServer(*actual);
+ LiteralTestUtil::ExpectEqual(*round_tripped, *actual);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
new file mode 100644
index 0000000000..04a8bab0eb
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests transferring literals of various shapes and values in and out of the
+// XLA service.
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class RoundTripTransferTest : public ClientLibraryTestBase {
+ protected:
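+  // Transfers |original| to the service, reads it back, and checks that the
+  // round-tripped literal equals the original.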
+ void RoundTripTest(const Literal& original) {
+ std::unique_ptr<GlobalData> data =
+ client_->TransferToServer(original).ConsumeValueOrDie();
+ std::unique_ptr<Literal> result =
+ client_->Transfer(*data).ConsumeValueOrDie();
+ LiteralTestUtil::ExpectEqual(original, *result);
+ }
+};
+
+TEST_F(RoundTripTransferTest, R0S32) {
+ RoundTripTest(*LiteralUtil::CreateR0<int32>(42));
+}
+
+TEST_F(RoundTripTransferTest, R0F32) {
+ RoundTripTest(*LiteralUtil::CreateR0<float>(42.0));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len0) {
+ RoundTripTest(*LiteralUtil::CreateR1<float>({}));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len2) {
+ RoundTripTest(*LiteralUtil::CreateR1<float>({42.0, 64.0}));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len256) {
+ std::vector<float> values(256);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len1024) {
+ std::vector<float> values(1024);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len1025) {
+ std::vector<float> values(1025);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R1F32_Len4096) {
+ std::vector<float> values(4096);
+ std::iota(values.begin(), values.end(), 1.0);
+ RoundTripTest(*LiteralUtil::CreateR1<float>(values));
+}
+
+TEST_F(RoundTripTransferTest, R2F32_Len10x0) {
+ RoundTripTest(
+ *LiteralUtil::CreateR2FromArray2D<float>(Array2D<float>(10, 0)));
+}
+
+TEST_F(RoundTripTransferTest, R2F32_Len2x2) {
+ RoundTripTest(*LiteralUtil::CreateR2<float>({{42.0, 64.0}, {77.0, 88.0}}));
+}
+
+TEST_F(RoundTripTransferTest, R3F32) {
+ RoundTripTest(
+ *LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
+ {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}));
+}
+
+TEST_F(RoundTripTransferTest, R4F32) {
+ RoundTripTest(*LiteralUtil::CreateR4<float>({{
+ {{10, 11, 12, 13}, {14, 15, 16, 17}},
+ {{18, 19, 20, 21}, {22, 23, 24, 25}},
+ {{26, 27, 28, 29}, {30, 31, 32, 33}},
+ }}));
+}
+
+TEST_F(RoundTripTransferTest, EmptyTuple) {
+ RoundTripTest(*LiteralUtil::MakeTuple({}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR1F32) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({1, 2}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>({}).get(),
+ LiteralUtil::CreateR1<float>({3, 4}).get()}));
+}
+
+TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) {
+ RoundTripTest(
+ *LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(1.0).get(),
+ LiteralUtil::CreateR1<int>({2, 3}).get()}));
+}
+
+// The two tests below measure the cost of large data transfers.
+TEST_F(RoundTripTransferTest, R2F32_Large) {
+ RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512));
+}
+
+TEST_F(RoundTripTransferTest, R4F32_Large) {
+ Array4D<float> array4d(2, 2, 256, 256);
+ array4d.FillWithMultiples(1.0f);
+ RoundTripTest(*LiteralUtil::CreateR4FromArray4D<float>(array4d));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
new file mode 100644
index 0000000000..bd9cae4d1d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -0,0 +1,630 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <limits>
+#include <memory>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/legacy_flags/llvm_backend_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class ScalarComputationsTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+
+ protected:
+ // A template for building and running a binary comparison test.
+ template <typename NativeT>
+ void TestCompare(NativeT lhs, NativeT rhs, bool expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
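+    // Invoke the comparison through the member-function pointer; the empty
+    // ArraySlice argument means no broadcast dimensions.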
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<bool>(&builder, expected, {});
+ }
+
+ template <typename NativeT>
+ void TestMinMax(NativeT lhs, NativeT rhs, NativeT expected,
+ ComputationDataHandle (ComputationBuilder::*op)(
+ const ComputationDataHandle&,
+ const ComputationDataHandle&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle lhs_op = builder.ConstantR0<NativeT>(lhs);
+ ComputationDataHandle rhs_op = builder.ConstantR0<NativeT>(rhs);
+ ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ ComputeAndCompareR0<NativeT>(&builder, expected, {});
+ }
+};
+
+TEST_F(ScalarComputationsTest, NegateScalarF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<float>(2.1f));
+
+ ComputeAndCompareR0<float>(&builder, -2.1f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, NegateScalarS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Neg(builder.ConstantR0<int32>(2));
+
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, 7.6f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, 7, {});
+}
+
+TEST_F(ScalarComputationsTest, AddTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint32>(35), builder.ConstantR0<uint32>(57));
+
+ ComputeAndCompareR0<uint32>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU8) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<uint8>(35), builder.ConstantR0<uint8>(57));
+
+ ComputeAndCompareR0<uint8>(&builder, 92, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsU64) {
+ ComputationBuilder builder(client_, TestName());
+ const uint64 a = static_cast<uint64>(1) << 63;
+ const uint64 b = a + 1;
+ builder.Add(builder.ConstantR0<uint64>(a), builder.ConstantR0<uint64>(b));
+
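+  // Unsigned overflow is well defined; a + b wraps modulo 2^64 to 1.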
+ ComputeAndCompareR0<uint64>(&builder, a + b, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsS64) {
+  ComputationBuilder builder(client_, TestName());
+  const int64 a = static_cast<int64>(1) << 62;
+  const int64 b = a + 1;
+  builder.Add(builder.ConstantR0<int64>(a), builder.ConstantR0<int64>(b));
+
+  // a + b overflows int64. Compute the wrapped expected value with unsigned
+  // arithmetic to avoid undefined behavior in the test itself.
+  const int64 expected =
+      static_cast<int64>(static_cast<uint64>(a) + static_cast<uint64>(b));
+  ComputeAndCompareR0<int64>(&builder, expected, {});
+}
+
+XLA_TEST_F(ScalarComputationsTest, AddTwoScalarsF64) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Add(builder.ConstantR0<double>(0.25),
+ builder.ConstantR0<double>(3.5));
+
+ ComputeAndCompareR0<double>(&builder, 3.75, {});
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<float>(2.1f), builder.ConstantR0<float>(5.5f));
+
+ ComputeAndCompareR0<float>(&builder, -3.4f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Sub(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5));
+
+ ComputeAndCompareR0<int32>(&builder, -3, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
+ builder.ConstantR0<float>(5.5f)),
+ builder.ConstantR0<float>(0.5f));
+
+ ComputeAndCompareR0<float>(&builder, 5.775f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsS32) {
+ std::vector<int32> data = {0,
+ 1,
+ -1,
+ 1234,
+ 0x1a243514,
+ std::numeric_limits<int32>::max(),
+ std::numeric_limits<int32>::min()};
+
+ for (int32 x : data) {
+ for (int32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<int32>(x), builder.ConstantR0<int32>(y));
+
+ // Signed integer overflow is undefined behavior in C++. Convert the input
+ // integers to unsigned, perform the multiplication unsigned, and convert
+ // back.
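+      // For example, max() * min() wraps around to min().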
+ int32 expected = static_cast<uint32>(x) * static_cast<uint32>(y);
+
+ ComputeAndCompareR0<int32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulTwoScalarsU32) {
+ std::vector<uint32> data = {0, 1, 0xDEADBEEF, 1234,
+ 0x1a243514, 0xFFFFFFFF, 0x80808080};
+
+ for (uint32 x : data) {
+ for (uint32 y : data) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(builder.ConstantR0<uint32>(x), builder.ConstantR0<uint32>(y));
+
+ uint32 expected = x * y;
+ ComputeAndCompareR0<uint32>(&builder, expected, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Mul(
+ builder.Mul(builder.ConstantR0<int32>(2), builder.ConstantR0<int32>(5)),
+ builder.ConstantR0<int32>(1));
+
+ ComputeAndCompareR0<int32>(&builder, 10, {});
+}
+
+TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) {
+ ComputationBuilder builder(client_, TestName());
+ std::unique_ptr<Literal> a_literal = LiteralUtil::CreateR0<float>(2.1f);
+ std::unique_ptr<Literal> b_literal = LiteralUtil::CreateR0<float>(5.5f);
+ std::unique_ptr<Literal> c_literal = LiteralUtil::CreateR0<float>(0.5f);
+
+ std::unique_ptr<GlobalData> a_data =
+ client_->TransferToServer(*a_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> b_data =
+ client_->TransferToServer(*b_literal).ConsumeValueOrDie();
+ std::unique_ptr<GlobalData> c_data =
+ client_->TransferToServer(*c_literal).ConsumeValueOrDie();
+
+ ComputationDataHandle a = builder.Parameter(0, a_literal->shape(), "a");
+ ComputationDataHandle b = builder.Parameter(1, b_literal->shape(), "b");
+ ComputationDataHandle c = builder.Parameter(2, c_literal->shape(), "c");
+ builder.Mul(builder.Mul(a, b), c);
+
+ ComputeAndCompareR0<float>(&builder, 5.775f,
+ {a_data.get(), b_data.get(), c_data.get()},
+ error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, DivideTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<float>(5.0f), builder.ConstantR0<float>(2.5f));
+
+ ComputeAndCompareR0<float>(&builder, 2.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsF32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<float>(2.5f), builder.ConstantR0<float>(5.0f));
+
+ ComputeAndCompareR0<float>(&builder, 2.5f, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Div(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
+
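+  // Integer division truncates toward zero, so -5 / 2 yields -2, not -3.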
+ ComputeAndCompareR0<int32>(&builder, -2, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsNegativeResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(-5), builder.ConstantR0<int32>(2));
+
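+  // The remainder takes the sign of the dividend: -5 % 2 == -1.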
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(7919));
+
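+  // -2147483648 == 7919 * -271181 - 1309, so the remainder is -1309.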
+ ComputeAndCompareR0<int32>(&builder, -1309, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsIntMinVsIntMaxS32) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Rem(builder.ConstantR0<int32>(INT_MIN),
+ builder.ConstantR0<int32>(INT_MAX));
+
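+  // INT_MIN == -(INT_MAX + 1), so the remainder of INT_MIN / INT_MAX is -1.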
+ ComputeAndCompareR0<int32>(&builder, -1, {});
+}
+
+TEST_F(ScalarComputationsTest, RemainderTwoScalarsPositiveResultS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
+ builder.Rem(x, builder.ConstantR0<int32>(80000));
+
+ std::unique_ptr<Literal> literal = LiteralUtil::CreateR0<int32>(87919);
+ TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal));
+ ComputeAndCompareR0<int32>(&builder, 7919, {input_data.get()});
+}
+
+XLA_TEST_F(ScalarComputationsTest, DivideTwoScalarsU32) {
+ ComputationBuilder builder(client_, TestName());
+ // This verifies 0xFFFFFFFE / 2 = 0x7FFFFFFF. If XLA incorrectly treated U32
+ // as S32, it would output -2 / 2 = -1 (0xFFFFFFFF).
+ builder.Div(builder.ConstantR0<uint32>(0xFFFFFFFE),
+ builder.ConstantR0<uint32>(2));
+
+ ComputeAndCompareR0<uint32>(&builder, 0x7FFFFFFF, {});
+}
+
+TEST_F(ScalarComputationsTest, LogicalAnd) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalAnd(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x && y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalOr) {
+ for (bool x : {false, true}) {
+ for (bool y : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalOr(builder.ConstantR0<bool>(x),
+ builder.ConstantR0<bool>(y));
+
+ ComputeAndCompareR0<bool>(&builder, x || y, {});
+ }
+ }
+}
+
+TEST_F(ScalarComputationsTest, LogicalNot) {
+ for (bool x : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ builder.LogicalNot(builder.ConstantR0<bool>(x));
+
+ ComputeAndCompareR0<bool>(&builder, !x, {});
+ }
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarTrue) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(true), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, SelectScalarFalse) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Select(builder.ConstantR0<bool>(false), // The predicate.
+ builder.ConstantR0<float>(123.0f), // The value on true.
+ builder.ConstantR0<float>(42.0f)); // The value on false.
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+// This test is an explicit version of what happens in the templatized
+// comparison tests below.
+TEST_F(ScalarComputationsTest, CompareGtScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Gt(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(1.0f));
+
+ ComputeAndCompareR0<bool>(&builder, true, {});
+}
+
+// S32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqS32Greater) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Eq);
+}
+TEST_F(ScalarComputationsTest, CompareEqS32Equal) {
+ TestCompare<int32>(3, 3, true, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeS32) {
+ TestCompare<int32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtS32) {
+ TestCompare<int32>(1, 5, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeS32) {
+ TestCompare<int32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtS32) {
+ TestCompare<int32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<int32>(std::numeric_limits<int32>::min(),
+ std::numeric_limits<int32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// U32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqU32False) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeU32) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Greater) {
+ TestCompare<uint32>(2, 1, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeU32Equal) {
+ TestCompare<uint32>(3, 3, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtU32) {
+ TestCompare<uint32>(1, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 5, false, &ComputationBuilder::Gt);
+ TestCompare<uint32>(5, 1, true, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeU32) {
+ TestCompare<uint32>(2, 1, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtU32) {
+ TestCompare<uint32>(9, 7, false, &ComputationBuilder::Lt);
+ TestCompare<uint32>(0, std::numeric_limits<uint32>::max(), true,
+ &ComputationBuilder::Lt);
+}
+
+// F32 comparisons.
+TEST_F(ScalarComputationsTest, CompareEqF32False) {
+ TestCompare<float>(2.0, 1.3, false, &ComputationBuilder::Eq);
+}
+
+TEST_F(ScalarComputationsTest, CompareNeF32) {
+ TestCompare<float>(2.0, 1.3, true, &ComputationBuilder::Ne);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32Greater) {
+ TestCompare<float>(2.0, 1.9, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32Equal) {
+ TestCompare<float>(3.5, 3.5, true, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, CompareGtF32) {
+ TestCompare<float>(1.0, 5.2, false, &ComputationBuilder::Gt);
+}
+
+TEST_F(ScalarComputationsTest, CompareLeF32) {
+ TestCompare<float>(2.0, 1.2, false, &ComputationBuilder::Le);
+}
+
+TEST_F(ScalarComputationsTest, CompareLtF32) {
+ TestCompare<float>(9.0, 7.2, false, &ComputationBuilder::Lt);
+}
+
+// F32 comparisons with exceptional values. The test names encode the
+// left/right operands at the end, and use Minf and Mzero for -inf and -0.0.
+TEST_F(ScalarComputationsTest, CompareLtF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, true, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32MzeroZero) {
+  // IEEE 754 treats -0.0 and 0.0 as equal, so -0.0 < 0.0 is false.
+ TestCompare<float>(-0.0, 0.0, false, &ComputationBuilder::Lt);
+}
+TEST_F(ScalarComputationsTest, CompareLtF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, true, &ComputationBuilder::Lt);
+}
+
+TEST_F(ScalarComputationsTest, CompareGeF32MinfMzero) {
+ TestCompare<float>(-INFINITY, -0.0, false, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32MzeroZero) {
+  // IEEE 754 treats -0.0 and 0.0 as equal, so -0.0 >= 0.0 is true.
+ TestCompare<float>(-0.0, 0.0, true, &ComputationBuilder::Ge);
+}
+TEST_F(ScalarComputationsTest, CompareGeF32ZeroInf) {
+ TestCompare<float>(0.0, INFINITY, false, &ComputationBuilder::Ge);
+}
+
+TEST_F(ScalarComputationsTest, ExpScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Exp(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 7.3890562, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, LogScalar) {
+ ComputationBuilder builder(client_, "log");
+ builder.Log(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.6931471, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, TanhScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<float>(2.0f));
+
+ ComputeAndCompareR0<float>(&builder, 0.96402758, {}, error_spec_);
+}
+
+XLA_TEST_F(ScalarComputationsTest, TanhDoubleScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Tanh(builder.ConstantR0<double>(2.0));
+
+ ComputeAndCompareR0<double>(&builder, 0.96402758, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, PowScalar) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Pow(builder.ConstantR0<float>(2.0f), builder.ConstantR0<float>(3.0f));
+
+ ComputeAndCompareR0<float>(&builder, 8.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarHigh) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 3.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarMiddle) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(2.5f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ClampScalarLow) {
+ ComputationBuilder builder(client_, TestName());
+ builder.Clamp(builder.ConstantR0<float>(2.0f), // The lower bound.
+ builder.ConstantR0<float>(-5.0f), // The operand to be clamped.
+ builder.ConstantR0<float>(3.0f)); // The upper bound.
+
+ ComputeAndCompareR0<float>(&builder, 2.0, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Above) {
+ TestMinMax<int32>(10, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinS32Below) {
+ TestMinMax<int32>(-100, 3, -100, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Above) {
+ TestMinMax<int32>(10, 3, 10, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxS32Below) {
+ TestMinMax<int32>(-100, 3, 3, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, 3, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinU32Below) {
+ TestMinMax<uint32>(0, 5, 0, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Above) {
+ const uint32 large = std::numeric_limits<int32>::max();
+ TestMinMax<uint32>(large, 3, large, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxU32Below) {
+ TestMinMax<uint32>(0, 5, 5, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 3.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MinF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, -100.1f, &ComputationBuilder::Min);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Above) {
+ TestMinMax<float>(10.1f, 3.1f, 10.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, MaxF32Below) {
+ TestMinMax<float>(-100.1f, 3.1f, 3.1f, &ComputationBuilder::Max);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionF32) {
+ // Compute the expression (1 * (3 - 1) * (7 + 0) - 4) / 20.
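+  // This evaluates to (1 * 2 * 7 - 4) / 20 = 10 / 20 = 0.5.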
+ ComputationBuilder b(client_, TestName());
+ b.Div(
+ b.Sub(b.Mul(b.ConstantR0<float>(1),
+ b.Mul(b.Sub(b.ConstantR0<float>(3), b.ConstantR0<float>(1)),
+ b.Add(b.ConstantR0<float>(7), b.ConstantR0<float>(0)))),
+ b.ConstantR0<float>(4)),
+ b.ConstantR0<float>(20));
+
+ ComputeAndCompareR0<float>(&b, 0.5, {}, error_spec_);
+}
+
+TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) {
+ // Compute the expression 1 * (3 - 1) * (7 + 0) - 4.
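+  // This evaluates to 1 * 2 * 7 - 4 = 10.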
+ ComputationBuilder b(client_, TestName());
+ b.Sub(b.Mul(b.ConstantR0<int32>(1),
+ b.Mul(b.Sub(b.ConstantR0<int32>(3), b.ConstantR0<int32>(1)),
+ b.Add(b.ConstantR0<int32>(7), b.ConstantR0<int32>(0)))),
+ b.ConstantR0<int32>(4));
+
+ ComputeAndCompareR0<int32>(&b, 10, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendLlvmBackendFlags(&flag_list);
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
new file mode 100644
index 0000000000..fb1effc8c4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc
@@ -0,0 +1,395 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests the select-and-scatter XLA operation.
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SelectAndScatterTest : public ClientLibraryTestBase {
+ public:
+ SelectAndScatterTest() : builder_(client_, TestName()) {
+    // Create scalar GE (select) and ADD (scatter) computations for S32 and
+    // F32, plus F32 MAX and MIN scatter computations.
+ ge_s32_ = CreateScalarGeComputation(S32, &builder_);
+ add_s32_ = CreateScalarAddComputation(S32, &builder_);
+ ge_f32_ = CreateScalarGeComputation(F32, &builder_);
+ add_f32_ = CreateScalarAddComputation(F32, &builder_);
+ max_f32_ = CreateScalarMaxComputation(F32, &builder_);
+ min_f32_ = CreateScalarMinComputation(F32, &builder_);
+ }
+
+ ComputationBuilder builder_;
+ Computation ge_s32_;
+ Computation add_s32_;
+ Computation ge_f32_;
+ Computation add_f32_;
+ Computation max_f32_;
+ Computation min_f32_;
+};
+
+// Test for F32 1D array, with a zero-element input.
+XLA_TEST_F(SelectAndScatterTest, R1S0F32) {
+ const auto operand = builder_.ConstantR1<float>({});
+ const auto source = builder_.ConstantR1<float>({});
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR1<float>(&builder_, {}, {}, ErrorSpec(1e-7));
+}
+
+// Test for F32 1D array, when windows do not overlap.
+XLA_TEST_F(SelectAndScatterTest, R1F32) {
+ const auto operand =
+ builder_.ConstantR1<float>({1.f, 9.f, 3.f, 7.f, 5.f, 6.f});
+ const auto source = builder_.ConstantR1<float>({34.f, 42.f});
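+  // Window {1, 9, 3} selects the 9 (index 1), which receives 34; window
+  // {7, 5, 6} selects the 7 (index 3), which receives 42. All other positions
+  // keep the init value 0.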
+ const std::vector<float> expected = {0.f, 34.f, 0.f, 42.f, 0.f, 0.f};
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+// Test for S32 1D array, when windows do not overlap and the init value is 1.
+XLA_TEST_F(SelectAndScatterTest, R1S32) {
+ const auto operand = builder_.ConstantR1<int32>({-1, 0, 6, 4, -4, 10});
+ const auto source = builder_.ConstantR1<int32>({-10, 20});
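+  // The first window {-1, 0, 6} selects the 6 (index 2): 1 + -10 = -9. The
+  // second window {4, -4, 10} selects the 10 (index 5): 1 + 20 = 21. All
+  // other positions keep the init value 1.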
+ const std::vector<int32> expected = {1, 1, -9, 1, 1, 21};
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(1), add_s32_);
+ ComputeAndCompareR1<int32>(&builder_, expected, {});
+}
+
+// Test for S32 1D array, when windows overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R1S32OverlappingWindow) {
+ const auto operand = builder_.ConstantR1<int32>({1, 9, 3, 7, 5, 6});
+ const auto source = builder_.ConstantR1<int32>({34, 42, 53, 19});
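+  // With stride 1 the four windows overlap. The first two windows both select
+  // the 9 (index 1), which accumulates 34 + 42 = 76; the last two both select
+  // the 7 (index 3), which accumulates 53 + 19 = 72.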
+ const std::vector<int32> expected = {0, 76, 0, 72, 0, 0};
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{3},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR1<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when windows do not overlap.
+XLA_TEST_F(SelectAndScatterTest, R2S32) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 10, 2}, {3, 8, 9, 3, 4, 2}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6}});
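+  // The left 2x3 window selects the 9 at (1, 2) and receives 2; the right
+  // window selects the 10 at (0, 4) and receives 6.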
+ Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Similar to SelectAndScatterTest.R2S32 but the input is transposed.
+XLA_TEST_F(SelectAndScatterTest, ReshapeR2S32) {
+ const auto operand = builder_.ConstantR2<int32>(
+ {{7, 3}, {2, 8}, {5, 9}, {3, 3}, {10, 4}, {2, 2}});
+ const auto reshape =
+ builder_.Reshape(operand, /*dimensions=*/{1, 0}, /*new_sizes=*/{2, 6});
+ const auto source = builder_.ConstantR2<int32>({{2, 6}});
+ Array2D<int32> expected({{0, 0, 0, 0, 6, 0}, {0, 0, 2, 0, 0, 0}});
+ builder_.SelectAndScatter(reshape, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{2, 3}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when windows overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R2S32OverlappingWindow) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ Array2D<int32> expected({{0, 0, 0, 0, 0}, {0, 0, 12, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 3},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when the padding is Padding::kSame.
+XLA_TEST_F(SelectAndScatterTest, R2S32SamePadding) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source = builder_.ConstantR2<int32>({{2, 6, 4}});
+ Array2D<int32> expected({{0, 0, 0, 0, 4}, {0, 2, 6, 0, 0}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{2, 2}, Padding::kSame, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+// Test for S32 2D array, when the padding is Padding::kSame and windows
+// overlap with each other.
+XLA_TEST_F(SelectAndScatterTest, R2S32SamePaddingOverlappingWindow) {
+ const auto operand =
+ builder_.ConstantR2<int32>({{7, 2, 5, 3, 8}, {3, 8, 9, 3, 4}});
+ const auto source =
+ builder_.ConstantR2<int32>({{2, 6, 4, 7, 1}, {3, 5, 8, 9, 10}});
+ Array2D<int32> expected({{0, 0, 0, 0, 8}, {0, 5, 23, 0, 19}});
+ builder_.SelectAndScatter(operand, ge_s32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kSame, source,
+ builder_.ConstantR0<int32>(0), add_s32_);
+ ComputeAndCompareR2<int32>(&builder_, expected, {});
+}
+
+XLA_TEST_F(SelectAndScatterTest, R2F32OverlappingR2Source) {
+ const auto operand = builder_.ConstantR2<float>(
+ {{1.5f, 2.5f, 1.5f}, {3.5f, 1.5f, 3.5f}, {4.5f, 2.5f, 4.5f}});
+ const auto source = builder_.ConstantR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
+ Array2D<float> expected(
+ {{0.0f, 0.0f, 0.0f}, {1.0f, 0.0f, 2.0f}, {3.0f, 0.0f, 4.0f}});
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{2, 2},
+ /*window_strides=*/{1, 1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR2<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32Valid) {
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
+ {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 6.0f, 0.0f},
+ {0.0f, 0.0f, 2.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}};
+ Array4D<float> o(4, 6, 15, 220);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 6, 15, 220);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 15, 220);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ s.FillWithPZ(pzs);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32Overlap) {
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
+ {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
+ Array4D<float> o(4, 5, 17, 128);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 5, 17, 128);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 17, 128);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ s.FillWithPZ(pzs);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32OverlapSmall) {
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 8.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.0f},
+ {0.0f, 6.0f, 2.0f, 10.0f, 2.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array2D<float> pze = {{0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 8.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 3.0f, 0.0f, 0.0f},
+ {0.0f, 0.0f, 0.0f, 1.0f, 0.0f}};
+ Array4D<float> o(4, 5, 1, 1);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> e(4, 5, 1, 1);
+ e.FillWithPZ(pze);
+ Array4D<float> s(2, 2, 1, 1);
+ s.FillWithPZ(pzs);
+ auto source = builder_.ConstantR4FromArray4D(s);
+ s.FillWithPZ(pzs);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 2, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ ComputeAndCompareR4<float>(&builder_, e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidFixedSmall) {
+  // This test checks the result against the ReferenceUtil implementation.
+ Array2D<float> pzo = {{7.0f, 2.0f, 5.0f, 3.0f, 10.0f, 2.0f},
+ {3.0f, 8.0f, 9.0f, 3.0f, 4.00f, 2.0f},
+ {1.0f, 5.0f, 7.0f, 5.0f, 6.00f, 1.0f},
+ {0.0f, 6.0f, 2.0f, 7.0f, 2.00f, 8.0f}};
+ Array2D<float> pzs = {{2.0f, 6.0f}, {3.0f, 1.0f}};
+ Array4D<float> o(4, 6, 4, 4);
+ o.FillWithPZ(pzo);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+ Array4D<float> s(2, 2, 4, 4);
+ s.FillWithPZ(pzs);
+
+ auto source = builder_.ConstantR4FromArray4D(s);
+ s.FillWithPZ(pzs);
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 3, 1, 1}, {2, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 3, 1, 1},
+ {2, 3, 1, 1}, false);
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefSameRandom) {
+ Array4D<float> o(7, 7, 8, 256);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(4, 4, 8, 256);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {2, 2, 1, 1}, {2, 2, 1, 1},
+ Padding::kSame, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {2, 2, 1, 1},
+ {2, 2, 1, 1}, true);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefSameRandomFullyPadded) {
+ Array4D<float> o(1, 1, 5, 5);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(1, 1, 5, 5);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kSame, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, true);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidRandom) {
+ Array4D<float> o(9, 9, 16, 128);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(3, 3, 16, 128);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, false);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+TEST_F(SelectAndScatterTest, R4F32RefValidRandomSmall) {
+ Array4D<float> o(3, 3, 4, 4);
+ o.FillRandom(1.5f);
+ auto operand = builder_.ConstantR4FromArray4D(o);
+
+ Array4D<float> s(1, 1, 4, 4);
+ s.FillRandom(12.0f);
+ auto source = builder_.ConstantR4FromArray4D(s);
+
+ builder_.SelectAndScatter(operand, ge_f32_, {3, 3, 1, 1}, {3, 3, 1, 1},
+ Padding::kValid, source,
+ builder_.ConstantR0<float>(0.0f), add_f32_);
+
+ auto e = ReferenceUtil::SelectAndScatter4DGePlus(o, s, 0.0f, {3, 3, 1, 1},
+ {3, 3, 1, 1}, false);
+
+ ComputeAndCompareR4<float>(&builder_, *e, {}, ErrorSpec(1e-7));
+}
+
+XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMaxScatter) {
+ const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
+ const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
+ const std::vector<float> expected = {0, 0, 0, 53, 0, 0, 0};
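+  // With window size 4 and stride 1, each of the four windows selects the
+  // maximum element 100 at index 3, so every source value is scattered there;
+  // the max-scatter over the init value 0 leaves max(34, 42, 53, 19) = 53.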
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(0), max_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) {
+ const auto operand = builder_.ConstantR1<float>({1, 2, 3, 100, 3, 2, 1});
+ const auto source = builder_.ConstantR1<float>({34, 42, 53, 19});
+ const float max_float = std::numeric_limits<float>::max();
+ const std::vector<float> expected = {max_float, max_float, max_float, 19,
+ max_float, max_float, max_float};
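+  // As above, every window selects index 3; the min-scatter over the
+  // max_float init writes min(34, 42, 53, 19) = 19 there and leaves the other
+  // positions at max_float.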
+ builder_.SelectAndScatter(operand, ge_f32_, /*window_dimensions=*/{4},
+ /*window_strides=*/{1}, Padding::kValid, source,
+ builder_.ConstantR0<float>(max_float), min_f32_);
+ ComputeAndCompareR1<float>(&builder_, expected, {}, ErrorSpec(1e-7));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc
new file mode 100644
index 0000000000..5ec9ac95fa
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/select_test.cc
@@ -0,0 +1,276 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SelectTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+TEST_F(SelectTest, SelectScalarF32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 123.0f, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectScalarS32True) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR0<int32>(-42);
+ auto on_false = builder.ConstantR0<int32>(42);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<int32>(&builder, -42, {});
+}
+
+TEST_F(SelectTest, SelectScalarF32False) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR0<float>(123.0f);
+ auto on_false = builder.ConstantR0<float>(42.0f);
+ auto result = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR0<float>(&builder, 42.0f, {}, error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithConstantR1S0PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({});
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithConstantR1PRED) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR1<bool>({false, true, false, true, false});
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithCmpR1S0S32s) {
+ // Similar to SelectR1S0F32WithConstantR1S0PRED, except that the pred vector
+ // is not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({});
+ auto v2 = builder.ConstantR1<int32>({});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32s) {
+ // Similar to SelectR1F32WithConstantR1PRED, except that the pred vector is
+ // not a constant, but rather the result of comparing two other vectors.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<int32>({1, 2, 3, 4, 5});
+ auto v2 = builder.ConstantR1<int32>({9, 2, 9, 4, 9});
+ auto cmp = builder.Eq(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 25.5f, 1.0f, -10.0f, -6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32s) {
+ // Similar to SelectR1F32WithCmpR1S32s, except "gt"-comparing two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v1 = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f});
+ auto v2 = builder.ConstantR1<float>({-1.0f, -2.0f, 13.0f, 14.0f, 4.4f});
+ auto cmp = builder.Gt(v1, v2);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f, 2.25f, -10.0f, 6.0f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f, 1.0f, 10.0f, -6.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f, 1.0f, 10.0f, 6.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsSmall) {
+  // Selects between two R1F32s that come from parameters. v1 and v2 are
+  // compared, and the selection between them is driven by a gt-comparison mask.
+ ComputationBuilder builder(client_, TestName());
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+ {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+ {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32sFromParamsLarge) {
+ // Similar to SelectR1F32WithCmpR1F32sFromParamsSmall, except that the
+ // data size passed in and out is large.
+ ComputationBuilder builder(client_, TestName());
+
+ // Number of floats in the data passed into and out of the computation.
+ constexpr int datalen = 15 * 1000;
+
+  // The inputs are initialized so that v1[i] > v2[i] in the first third of
+  // the data, and v2[i] > v1[i] elsewhere.
+ std::vector<float> v1vec;
+ std::vector<float> v2vec;
+ std::vector<float> expected_vec;
+ for (int i = 0; i < datalen; ++i) {
+ float smaller = i;
+ float larger = i * 2;
+ if (i < datalen / 3) {
+ v1vec.push_back(larger);
+ v2vec.push_back(smaller);
+ } else {
+ v1vec.push_back(smaller);
+ v2vec.push_back(larger);
+ }
+ expected_vec.push_back(larger);
+ }
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data =
+ CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data =
+ CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto cmp = builder.Gt(v1, v2);
+ auto select = builder.Select(cmp, v1, v2);
+ ComputeAndCompareR1<float>(&builder, expected_vec,
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1S32ToScalar) {
+  // "gt"-compares an R1S32 with an S32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<int32>({1, -1, 2, -2});
+ auto s = builder.ConstantR0<int32>(0);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {11.0f, -222.0f, 33.0f, -444.0f}, {},
+ error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithCmpR1F32ToScalar) {
+  // "gt"-compares an R1F32 with an F32 scalar, and uses the resulting R1PRED to
+ // select between two R1F32s.
+ ComputationBuilder builder(client_, TestName());
+ auto v = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
+ auto s = builder.ConstantR0<float>(2.5f);
+ auto cmp = builder.Gt(v, s);
+
+ auto on_true = builder.ConstantR1<float>({11.0f, 22.0f, 33.0f, 44.0f});
+ auto on_false =
+ builder.ConstantR1<float>({-111.0f, -222.0f, -333.0f, -444.0f});
+ auto select = builder.Select(cmp, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-111.0f, -222.0f, 33.0f, 44.0f}, {},
+ error_spec_);
+}
+
+XLA_TEST_F(SelectTest, SelectR1S0F32WithScalarPredicate) {
+ for (bool which : {false, true}) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(which);
+ auto on_true = builder.ConstantR1<float>({});
+ auto on_false = builder.ConstantR1<float>({});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, error_spec_);
+ }
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateTrue) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(true);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {-2.5f, 25.5f}, {}, error_spec_);
+}
+
+TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) {
+ ComputationBuilder builder(client_, TestName());
+ auto pred = builder.ConstantR0<bool>(false);
+ auto on_true = builder.ConstantR1<float>({-2.5f, 25.5f});
+ auto on_false = builder.ConstantR1<float>({10.0f, 5.0f});
+ auto select = builder.Select(pred, on_true, on_false);
+
+ ComputeAndCompareR1<float>(&builder, {10.0f, 5.0f}, {}, error_spec_);
+}
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc
new file mode 100644
index 0000000000..e15d744d95
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc
@@ -0,0 +1,116 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class SetReturnValueTest : public ClientLibraryTestBase {};
+
+TEST_F(SetReturnValueTest, NoSetValue) {
+ ComputationBuilder builder(client_, "no_set_value");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+
+ std::vector<float> expected = {1.0, 3.0, 4.0, 0.0, -1.0,
+ 5.0, 6.0, -2.0, -3.0, 7.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValue) {
+ ComputationBuilder builder(client_, "set_value");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+ auto builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValueAndModify) {
+ ComputationBuilder builder(client_, "set_value_and_modify");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+ auto builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaax = builder.Add(alpha, aax);
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) {
+ ComputationBuilder builder(client_, "set_value_multiple_times_and_modify");
+ auto alpha = builder.ConstantR0<float>(1.0);
+ auto x = builder.ConstantR1<float>(
+ {-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
+ auto ax = builder.Add(alpha, x);
+ auto aax = builder.Add(alpha, ax);
+ auto builder_status = builder.SetReturnValue(aax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaax = builder.Add(alpha, aax);
+ builder_status = builder.SetReturnValue(ax);
+ EXPECT_TRUE(builder_status.ok());
+ auto aaaax = builder.Add(alpha, aaax);
+
+ std::vector<float> expected = {0.0, 2.0, 3.0, -1.0, -2.0,
+ 4.0, 5.0, -3.0, -4.0, 6.0};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
new file mode 100644
index 0000000000..d63582fb98
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -0,0 +1,277 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests that slice operations can be performed.
+
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class SliceTest : public ClientLibraryTestBase {
+ protected:
+ template <typename NativeT>
+ void RunSliceTenToTwo() {
+ std::vector<NativeT> constant;
+ for (int i = 0; i < 10; ++i) {
+ constant.push_back(static_cast<NativeT>(i));
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<NativeT>(constant);
+ builder.Slice(original, {2}, {4});
+
+ const std::vector<NativeT> expected = {static_cast<NativeT>(2),
+ static_cast<NativeT>(3)};
+ ComputeAndCompareR1<NativeT>(&builder, expected, {});
+ }
+};
+
+XLA_TEST_F(SliceTest, SliceZeroToZeroF32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>({});
+ builder.Slice(original, {0}, {0});
+
+ ComputeAndCompareR1<float>(&builder, {}, {});
+}
+
+XLA_TEST_F(SliceTest, SliceTenToZeroF32) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> constant(10, 0.3);
+ auto original = builder.ConstantR1<float>(constant);
+ builder.Slice(original, {7}, {7});
+
+ ComputeAndCompareR1<float>(&builder, {}, {});
+}
+
+TEST_F(SliceTest, SliceTenToTwoF32) { RunSliceTenToTwo<float>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoF64) { RunSliceTenToTwo<double>(); }
+
+TEST_F(SliceTest, SliceTenToTwoU32) { RunSliceTenToTwo<uint32>(); }
+
+TEST_F(SliceTest, SliceTenToTwoS32) { RunSliceTenToTwo<int32>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoU64) { RunSliceTenToTwo<uint64>(); }
+
+XLA_TEST_F(SliceTest, SliceTenToTwoS64) { RunSliceTenToTwo<int64>(); }
+
+TEST_F(SliceTest, SliceTenToTen) {
+ const std::vector<float> values = {0.0, 1.0, 2.0, 3.0, 4.0,
+ 5.0, 6.0, 7.0, 8.0, 9.0};
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {0}, {10});
+
+ ComputeAndCompareR1<float>(&builder, values, {}, ErrorSpec(0.000001));
+}
+
+TEST_F(SliceTest, SliceLastFourOf1024) {
+ std::vector<float> values(1024);
+ std::iota(values.begin(), values.end(), 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {1024 - 4}, {1024});
+
+ const std::vector<float> expected = {1020, 1021, 1022, 1023};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// TODO(b/28491443): Fix wrong result on CPU and GPU. Failed on
+// 2016-05-01. Also b/28508652
+TEST_F(SliceTest, DISABLED_SliceUnaligned1024In4096Values) {
+ std::vector<float> values(4096);
+ std::iota(values.begin(), values.end(), 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR1<float>(values);
+ builder.Slice(original, {7}, {7 + 1024});
+
+ std::vector<float> expected(1024);
+  std::iota(expected.begin(), expected.end(), 7.0);
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+XLA_TEST_F(SliceTest, Slice0x0to0x0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
+ builder.Slice(original, {0, 0}, {0, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {});
+}
+
+XLA_TEST_F(SliceTest, Slice0x20to0x5F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 20));
+ builder.Slice(original, {0, 15}, {0, 20});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 5), {});
+}
+
+XLA_TEST_F(SliceTest, Slice3x0to2x0F32) {
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(Array2D<float>(3, 0));
+ builder.Slice(original, {1, 0}, {3, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(2, 0), {});
+}
+
+XLA_TEST_F(SliceTest, SliceQuadrantOf256x256) {
+ Array2D<float> values(256, 256);
+ for (int row = 0; row < 256; ++row) {
+ for (int col = 0; col < 256; ++col) {
+ values(row, col) = (row << 10) | col;
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {128, 128}, {256, 256});
+
+ Array2D<float> expected(128, 128);
+ for (int row = 0; row < 128; ++row) {
+ for (int col = 0; col < 128; ++col) {
+ expected(row, col) = ((row + 128) << 10) | (col + 128);
+ }
+ }
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests: (f32[1,4096], starts={0, 3072}, limits={1, 4096}) -> f32[1,1024]
+TEST_F(SliceTest, Slice_1x4096_To_1x1024) {
+ Array2D<float> values(1, 4096);
+ std::iota(values.data(), values.data() + 4096, 0.0);
+
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {0, 3072}, {1, 4096});
+
+ Array2D<float> expected(1, 1024);
+ std::iota(expected.data(), expected.data() + 1024, 3072.0);
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests slice: (f32[16,4], starts={0, 0}, limits={16, 2}) -> f32[16,2]
+TEST_F(SliceTest, Slice_16x4_To_16x2) {
+ Array2D<float> values(16, 4);
+ Array2D<float> expected(16, 2);
+ for (int row = 0; row < 16; ++row) {
+ for (int col = 0; col < 4; ++col) {
+ values(row, col) = (row << 10) | col;
+ if (col < 2) {
+ expected(row, col) = (row << 10) | col;
+ }
+ }
+ }
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR2FromArray2D<float>(values);
+ builder.Slice(original, {0, 0}, {16, 2});
+ ComputeAndCompareR2<float>(&builder, expected, {}, ErrorSpec(0.000001));
+}
+
+// Tests: (f32[2, 2, 24, 256], starts = {1, 0, 8, 0}, ends = {2, 2, 16, 128})
+TEST_F(SliceTest, SliceR4ThreeDimsMiddleMinor) {
+ Array4D<float> values(2, 2, 24, 256);
+ values.FillRandom(3.14f);
+ auto expected =
+ ReferenceUtil::Slice4D(values, {{1, 0, 8, 0}}, {{2, 2, 16, 128}});
+ ComputationBuilder builder(client_, TestName());
+ auto original = builder.ConstantR4FromArray4D(values);
+ builder.Slice(original, {1, 0, 8, 0}, {2, 2, 16, 128});
+ ComputeAndCompareR4(&builder, *expected, {}, ErrorSpec(0.000001));
+}
+
+struct R2Spec {
+ int64 input_dim0;
+ int64 input_dim1;
+ std::array<int64, 2> slice_starts;
+ std::array<int64, 2> slice_limits;
+ Layout layout;
+};
+
+// Parameterized test that generates patterned R2 values, slices them according
+// to the R2Spec, and compares the results with the ReferenceUtil version.
+class SliceR2Test : public ClientLibraryTestBase,
+ public ::testing::WithParamInterface<R2Spec> {};
+
+TEST_P(SliceR2Test, DoIt) {
+ const R2Spec& spec = GetParam();
+ Array2D<int32> input(spec.input_dim0, spec.input_dim1);
+ input.FillUnique();
+
+ ComputationBuilder builder(client_, TestName());
+ auto a = builder.ConstantR2FromArray2D<int32>(input);
+ builder.Slice(a, spec.slice_starts, spec.slice_limits);
+
+ std::unique_ptr<Array2D<int32>> expected =
+ ReferenceUtil::Slice2D(input, spec.slice_starts, spec.slice_limits);
+ ComputeAndCompareR2<int32>(&builder, *expected, {});
+}
+
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+ SliceR2TestInstantiation, SliceR2Test,
+ ::testing::Values(
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {4, 12, {{0, 3}}, {{4, 6}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({0, 1})},
+ R2Spec {16, 4, {{0, 2}}, {{16, 4}}, LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {256, 400, {{0, 300}}, {{256, 400}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {500, 400, {{111, 123}}, {{300, 257}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {500, 400, {{111, 123}}, {{300, 400}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {384, 512, {{128, 256}}, {{256, 384}},
+ LayoutUtil::MakeLayout({1, 0})},
+ R2Spec {357, 512, {{111, 256}}, {{301, 384}},
+ LayoutUtil::MakeLayout({1, 0})}
+ )
+);
+// clang-format on
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
new file mode 100644
index 0000000000..7f987a21ca
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Macros for use in enabling/disabling tests on particular
+// platforms. Marking a gunit test as disabled still ensures that it
+// compiles.
+//
+// Implementation note: the macros are structured as follows:
+// * Define the disabled macro to just pass the test name through (which, in
+// effect, does not disable it at all)
+// * If an XLA_TEST_BACKEND_$TARGET macro indicates we're compiling for the
+// $TARGET platform, make the disabled macro truly disable the test, i.e.
+// redefine the DISABLED_ON_$TARGET macro to prepend "DISABLED_" to the test
+// name.
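+//
+// For illustration only (the fixture and test names here are hypothetical): a
+// test declared as
+//   XLA_TEST_F(MyTest, DISABLED_ON_GPU(RunsExceptOnGpu)) { ... }
+// keeps the name RunsExceptOnGpu on CPU builds, but becomes
+// DISABLED_RunsExceptOnGpu when XLA_TEST_BACKEND_GPU is defined, so gunit
+// still compiles the body and merely skips it on that backend.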
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
+
+#include <string>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+// Use this macro instead of directly using TEST_P for parameterized tests,
+// otherwise DISABLED_ON_* macros nested in TEST_P will not get expanded since
+// TEST_P stringifies its argument. That makes the test disabled for all targets
+// when any one of the DISABLED_ON_* macros is used, and the test will just pass.
+// TODO(b/29122096): Remove this once TEST_P fixes this problem.
+#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name)
+
+#define DISABLED_ON_CPU(X) X
+#define DISABLED_ON_CPU_PARALLEL(X) X
+#define DISABLED_ON_GPU(X) X
+
+// We need this macro instead of pasting directly to support nesting
+// the DISABLED_ON_FOO macros, as in the definition of DISABLED_ON_CPU.
+// Otherwise the pasting is applied before macro expansion completes.
+#define XLA_TEST_PASTE(A, B) A##B
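+//
+// For illustration (hypothetical nesting): on a CPU-only build,
+//   DISABLED_ON_CPU(DISABLED_ON_GPU(Foo))
+// expands the inner macro to Foo before the paste happens, yielding
+// DISABLED_Foo, whereas a direct DISABLED_ ## X would paste first and never
+// expand the inner macro.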
+
+// We turn off clang-format so we can indent the macros for readability.
+// clang-format off
+
+#ifdef XLA_TEST_BACKEND_CPU
+# undef DISABLED_ON_CPU
+# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_CPU
+
+#ifdef XLA_TEST_BACKEND_CPU_PARALLEL
+# undef DISABLED_ON_CPU
+# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
+# undef DISABLED_ON_CPU_PARALLEL
+# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_CPU_PARALLEL
+
+#ifdef XLA_TEST_BACKEND_GPU
+# undef DISABLED_ON_GPU
+# define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X)
+#endif // XLA_TEST_BACKEND_GPU
+
+// clang-format on
+
+#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name)
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_
diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h
new file mode 100644
index 0000000000..6a23df4d3c
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/test_utils.h
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
+#define TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
+
+#include <initializer_list>
+#include <memory>
+#include <random>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace test_utils {
+
+// A class which generates pseudorandom numbers of a given type within a given
+// range. Not cryptographically secure and likely not perfectly evenly
+// distributed across the range but sufficient for most tests.
+template <typename NativeT>
+class PseudorandomGenerator {
+ public:
+ explicit PseudorandomGenerator(NativeT min_value, NativeT max_value,
+ uint32 seed)
+ : min_(min_value), max_(max_value), generator_(seed) {}
+
+ // Get a pseudorandom value.
+ NativeT get() {
+ std::uniform_real_distribution<> distribution;
+ return static_cast<NativeT>(min_ +
+ (max_ - min_) * distribution(generator_));
+ }
+
+ private:
+ NativeT min_;
+ NativeT max_;
+ std::mt19937 generator_;
+};
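+
+// A minimal usage sketch (values are purely illustrative):
+//   PseudorandomGenerator<float> gen(0.0f, 1.0f, /*seed=*/42);
+//   float value = gen.get();  // pseudorandom float in [0.0f, 1.0f)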
+
+// Convenience function for creating a rank-2 array with arbitrary layout.
+template <typename NativeT>
+std::unique_ptr<Literal> CreateR2LiteralWithLayout(
+ std::initializer_list<std::initializer_list<NativeT>> values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ auto literal = MakeUnique<Literal>();
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ LiteralUtil::PopulateWithValue<NativeT>(0, {d0, d1}, literal.get());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(minor_to_major);
+ TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto value : inner_list) {
+ LiteralUtil::Set(literal.get(), {dim0, dim1}, value);
+ ++dim1;
+ }
+ ++dim0;
+ }
+ return literal;
+}
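+
+// A minimal usage sketch (values and layout are purely illustrative):
+//   auto literal = CreateR2LiteralWithLayout<float>({{1, 2, 3}, {4, 5, 6}},
+//                                                   /*minor_to_major=*/{0, 1});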
+
+// Convenience function for creating a rank-3 array with arbitrary layout.
+template <typename NativeT>
+std::unique_ptr<Literal> CreateR3LiteralWithLayout(
+ std::initializer_list<std::initializer_list<std::initializer_list<NativeT>>>
+ values,
+ tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+ auto literal = MakeUnique<Literal>();
+ const int64 d0 = values.size();
+ const int64 d1 = values.begin()->size();
+ const int64 d2 = values.begin()->begin()->size();
+ LiteralUtil::PopulateWithValue<NativeT>(0, {d0, d1, d2}, literal.get());
+ *literal->mutable_shape()->mutable_layout() =
+ LayoutUtil::MakeLayout(minor_to_major);
+ TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape()));
+
+ int64 dim0 = 0;
+ for (auto inner_list : values) {
+ int64 dim1 = 0;
+ for (auto inner_inner_list : inner_list) {
+ int64 dim2 = 0;
+ for (auto value : inner_inner_list) {
+ LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value);
+ ++dim2;
+ }
+ ++dim1;
+ }
+ ++dim0;
+ }
+ return literal;
+}
+
+} // namespace test_utils
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_UTILS_H_
diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc
new file mode 100644
index 0000000000..79f251bbc4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/transpose_test.cc
@@ -0,0 +1,203 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/reference_util.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class TransposeTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+
+ protected:
+ void TestTransposeConstant021(size_t n1, size_t n2, size_t n3);
+};
+
+XLA_TEST_F(TransposeTest, Transpose0x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 0));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose0x42) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 42));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(42, 0), {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose7x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2FromArray2D<float>(Array2D<float>(7, 0));
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 7), {}, error_spec_);
+}
+
+TEST_F(TransposeTest, Transpose2x2) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto lhs = builder.ConstantR2<float>({
+ {1.0, 2.0}, {3.0, 4.0},
+ });
+ auto result = builder.Transpose(lhs, {1, 0});
+
+ Array2D<float> expected({{1.0f, 3.0f}, {2.0f, 4.0f}});
+
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TransposeTest, Transpose0x2x3_2x3x0) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>(Array3D<int32>(0, 2, 3));
+ auto result = builder.Transpose(operand, {1, 2, 0});
+
+ ComputeAndCompareR3<int32>(&builder, Array3D<int32>(2, 3, 0), {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_2x3x1) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {1, 2, 0});
+
+ Array3D<int32> expected({{{1}, {2}, {3}}, {{4}, {5}, {6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_3x2x1) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {2, 1, 0});
+
+ Array3D<int32> expected({{{1}, {4}}, {{2}, {5}}, {{3}, {6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, Transpose1x2x3_1x2x3) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto operand = builder.ConstantR3FromArray3D<int32>({{{1, 2, 3}, {4, 5, 6}}});
+ auto result = builder.Transpose(operand, {0, 1, 2});
+
+ Array3D<int32> expected({{{1, 2, 3}, {4, 5, 6}}});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, MultiTranspose3x2) {
+ Array2D<float> input({{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}});
+ Array2D<float> transposed({{1.0f, 3.0f, 5.0f}, {2.0f, 4.0f, 6.0f}});
+
+ for (int transposes = 0; transposes <= 10; ++transposes) {
+ ComputationBuilder builder(client_, "Transpose");
+ auto computed = builder.ConstantR2FromArray2D<float>(input);
+ for (int i = 0; i < transposes; ++i) {
+ computed = builder.Transpose(computed, {1, 0});
+ }
+ const Array2D<float>& expected = transposes % 2 == 0 ? input : transposed;
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+ }
+}
+
+// Test for transposing [1x1] matrix.
+TEST_F(TransposeTest, Small_1x1) {
+ auto aoperand = MakeLinspaceArray2D(0.0, 1.0, 1, 1);
+
+ ComputationBuilder builder(client_, "transpose_1x1");
+ auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
+ builder.Transpose(operand, {1, 0});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
+}
+
+// Test for transposing [2x2] matrix.
+TEST_F(TransposeTest, Small_2x2) {
+ auto aoperand = MakeLinspaceArray2D(0.0, 4.0, 2, 2);
+
+ ComputationBuilder builder(client_, "transpose_2x2");
+ auto operand = builder.ConstantR2FromArray2D<float>(*aoperand);
+ builder.Transpose(operand, {1, 0});
+
+ auto expected = ReferenceUtil::TransposeArray2D(*aoperand);
+ ComputeAndCompareR2<float>(&builder, *expected, {}, ErrorSpec(1e-4));
+}
+
+void TransposeTest::TestTransposeConstant021(size_t n1, size_t n2, size_t n3) {
+ Array3D<int32> aoperand(n1, n2, n3);
+ Array3D<int32> expected(n1, n3, n2);
+ for (size_t i = 0; i < n1; ++i) {
+ for (size_t j = 0; j < n2; ++j) {
+ for (size_t k = 0; k < n3; ++k) {
+ aoperand(i, j, k) = i * n3 * n2 + j * n3 + k;
+ expected(i, k, j) = aoperand(i, j, k);
+ }
+ }
+ }
+
+ ComputationBuilder builder(client_, TestName());
+ auto operand = builder.ConstantR3FromArray3D(aoperand);
+ builder.Transpose(operand, {0, 2, 1});
+
+ ComputeAndCompareR3<int32>(&builder, expected, {});
+}
+
+TEST_F(TransposeTest, TransposeConstant021_SingleIncompleteTilePerLayer) {
+ TestTransposeConstant021(2, 2, 3);
+}
+
+TEST_F(TransposeTest, TransposeConstant021_SingleCompleteTilePerLayer) {
+ TestTransposeConstant021(2, 32, 32);
+}
+
+TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) {
+ TestTransposeConstant021(2, 70, 35);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
new file mode 100644
index 0000000000..cea9316a6d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -0,0 +1,415 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <initializer_list>
+#include <memory>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class TupleTest : public ClientLibraryTestBase {
+ public:
+ ErrorSpec error_spec_{0.0001};
+};
+
+// Tests the creation of tuple data.
+XLA_TEST_F(TupleTest, TupleCreate) {
+ ComputationBuilder builder(client_, TestName());
+
+ const float constant_scalar = 7.3f;
+ std::initializer_list<float> constant_vector = {1.1f, 2.0f, 3.3f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.1f, 2.2f, 3.5f}, // row 0
+ {4.8f, 5.0f, 6.7f}, // row 1
+ };
+ auto result = builder.Tuple({builder.ConstantR0<float>(constant_scalar),
+ builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR0<float>(constant_scalar).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get(),
+ LiteralUtil::CreateR2<float>(constant_matrix).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Tests the creation of tuple data.
+XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) {
+ ComputationBuilder builder(client_, TestName());
+
+ auto result = builder.Tuple(
+ {builder.ConstantR0<float>(7.0), builder.ConstantR1<float>({})});
+
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR0<float>(7.0).get(),
+ LiteralUtil::CreateR1<float>({}).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Tests the creation of an empty tuple.
+XLA_TEST_F(TupleTest, EmptyTupleCreate) {
+ ComputationBuilder builder(client_, TestName());
+ auto result = builder.Tuple({});
+ auto expected = LiteralUtil::MakeTuple({});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Trivial test for extracting a tuple element with GetTupleElement.
+XLA_TEST_F(TupleTest, GetTupleElement) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(constant_matrix), {},
+ error_spec_);
+}
+
+// Trivial test for extracting a tuple element with GetTupleElement.
+XLA_TEST_F(TupleTest, GetTupleElementWithZeroElements) {
+ ComputationBuilder builder(client_, TestName());
+ auto tuple_data = builder.Tuple(
+ {builder.ConstantR1<float>({}),
+ builder.ConstantR2FromArray2D<float>(Array2D<float>(0, 101))});
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 101), {}, error_spec_);
+}
+
+// Extracts both elements from a tuple with GetTupleElement and then adds them
+// together.
+XLA_TEST_F(TupleTest, AddTupleElements) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto vector_element = builder.GetTupleElement(tuple_data, 0);
+ auto matrix_element = builder.GetTupleElement(tuple_data, 1);
+ auto vector_shape = builder.GetShape(vector_element).ConsumeValueOrDie();
+ auto matrix_shape = builder.GetShape(matrix_element).ConsumeValueOrDie();
+ auto result = builder.Add(matrix_element, vector_element,
+ /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({
+ {2.f, 4.f, 6.f}, // row 0
+ {5.f, 7.f, 9.f}, // row 1
+ });
+ ASSERT_TRUE(ShapeUtil::ShapeIs(*vector_shape, F32, {3}));
+ ASSERT_TRUE(ShapeUtil::ShapeIs(*matrix_shape, F32, {/*y=*/2, /*x=*/3}));
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+// Extracts both elements from a tuple and then puts them into a new tuple in
+// the opposite order.
+XLA_TEST_F(TupleTest, TupleGTEToTuple) {
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
+ auto expected = LiteralUtil::MakeTuple(
+ {LiteralUtil::CreateR2<float>(constant_matrix).get(),
+ LiteralUtil::CreateR1<float>(constant_vector).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+// Builds two new tuples from an existing tuple (by means of GetTupleElement),
+// then adds up the components of the new tuples.
+XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
+ //
+ // v------ --(GTE 0)-- --(GTE 0)----------
+ // \ / \ / \
+ // (tuple)-- (tuple01)-- \
+ // / | \ / \ \
+ // m------ | --(GTE 1)-- --(GTE 1)------------ \
+ // | \ \
+ // | (add)
+ // | / /
+ // |--------(GTE 1)-- --(GTE 0)------------ /
+ // \ \ / /
+ // \ (tuple10)-- /
+ // \ / \ /
+ // -----(GTE 0)-- --(GTE 1)----------
+ ComputationBuilder builder(client_, TestName());
+ std::initializer_list<float> constant_vector = {1.f, 2.f, 3.f};
+ std::initializer_list<std::initializer_list<float>> constant_matrix = {
+ {1.f, 2.f, 3.f}, // row 0
+ {4.f, 5.f, 6.f}, // row 1
+ };
+ auto tuple_data = builder.Tuple({builder.ConstantR1<float>(constant_vector),
+ builder.ConstantR2<float>(constant_matrix)});
+ auto new_tuple01 = builder.Tuple({builder.GetTupleElement(tuple_data, 0),
+ builder.GetTupleElement(tuple_data, 1)});
+ auto new_tuple10 = builder.Tuple({builder.GetTupleElement(tuple_data, 1),
+ builder.GetTupleElement(tuple_data, 0)});
+ auto vector_from_01 = builder.GetTupleElement(new_tuple01, 0);
+ auto vector_from_10 = builder.GetTupleElement(new_tuple10, 1);
+ auto matrix_from_01 = builder.GetTupleElement(new_tuple01, 1);
+ auto matrix_from_10 = builder.GetTupleElement(new_tuple10, 0);
+
+ auto addvectors = builder.Add(vector_from_01, vector_from_10);
+ auto addmatrices = builder.Add(matrix_from_01, matrix_from_10);
+
+ auto result = builder.Add(addmatrices, addvectors,
+ /*broadcast_dimensions=*/{1});
+
+ Array2D<float> expected({
+ {4.f, 8.f, 12.f}, // row 0
+ {10.f, 14.f, 18.f}, // row 1
+ });
+ ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
+ // Tests a selection between tuples with "false" path taken.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, TuplesInAMap) {
+ Computation tuple_computation;
+ {
+ // tuple_computation(x) = 100 * min(x, x^2) + max(x, x^2) using tuples.
+ //
+ // Need to put a select in there to prevent HLO-level optimizations from
+ // optimizing out the tuples.
+ ComputationBuilder b(client_, "sort_square");
+ auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto x2 = b.Mul(x, x);
+ auto x_smaller_tuple = b.Tuple({x, x2});
+ auto x2_smaller_tuple = b.Tuple({x2, x});
+ auto sorted = b.Select(b.Lt(x, x2), x_smaller_tuple, x2_smaller_tuple);
+ auto smaller = b.GetTupleElement(sorted, 0);
+ auto greater = b.GetTupleElement(sorted, 1);
+ b.Add(greater, b.Mul(b.ConstantR0<float>(100.0f), smaller));
+ auto computation_status = b.Build();
+ ASSERT_IS_OK(computation_status.status());
+ tuple_computation = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder b(client_, TestName());
+ auto input = b.ConstantR1<float>({-1.0f, 1.0f, 2.1f});
+ b.Map({input}, tuple_computation);
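+  // Worked example of the expected values below: for x = -1.0, x^2 = 1.0, so
+  // min = -1.0, max = 1.0 and the result is 100 * -1.0 + 1.0 = -99.0; x = 1.0
+  // gives 101.0 and x = 2.1 gives 100 * 2.1 + 4.41 = 214.41.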
+ ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
+ // Tests a selection between tuples with "true" path taken.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(true), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec1).get(),
+ LiteralUtil::CreateR1<float>(vec2).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
+ // Tests a selection between tuples but the final result is an element of the
+ // tuple, not the whole tuple.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto element = builder.GetTupleElement(select, 0);
+
+ ComputeAndCompareR1<float>(&builder, vec2, {}, error_spec_);
+}
+
+// Cascaded selects between tuple types.
+XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
+ //
+ // vec1 vec2 vec2 vec1
+ // | | | |
+ // | | | |
+ // (tuple 12) (tuple 21)
+ // \ /
+ // \ /
+ // \ /
+ // true -- --(GTE 0)--(select 1)
+ // \ / |
+ // (pred tuple)-- | --(GTE 0)--
+ // / \ V / \
+ // false -- --(GTE 1)--(select 2)-- --(add)
+ // / \ /
+ // / --(GTE 1)--
+ // /
+ // (tuple 21)
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+
+ auto pred_tuple = builder.Tuple(
+ {builder.ConstantR0<bool>(true), builder.ConstantR0<bool>(false)});
+ auto tuple12 = builder.Tuple(
+ {builder.ConstantR1<float>(vec1), builder.ConstantR1<float>(vec2)});
+ auto tuple21 = builder.Tuple(
+ {builder.ConstantR1<float>(vec2), builder.ConstantR1<float>(vec1)});
+
+ auto select1 =
+ builder.Select(builder.GetTupleElement(pred_tuple, 0), tuple12, tuple21);
+ auto select2 =
+ builder.Select(builder.GetTupleElement(pred_tuple, 1), tuple21, select1);
+ auto result = builder.Add(builder.GetTupleElement(select2, 0),
+ builder.GetTupleElement(select2, 1));
+
+ ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest,
+ DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
+ // Similar to SelectBetweenTuples, but the constants are shared between the
+ // input tuples.
+ ComputationBuilder builder(client_, TestName());
+
+ std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
+ std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
+ auto c1 = builder.ConstantR1<float>(vec1);
+ auto c2 = builder.ConstantR1<float>(vec2);
+ auto tuple12 = builder.Tuple({c1, c2});
+ auto tuple21 = builder.Tuple({c2, c1});
+
+ auto select =
+ builder.Select(builder.ConstantR0<bool>(false), tuple12, tuple21);
+ auto expected =
+ LiteralUtil::MakeTuple({LiteralUtil::CreateR1<float>(vec2).get(),
+ LiteralUtil::CreateR1<float>(vec1).get()});
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, NestedTuples) {
+ ComputationBuilder builder(client_, TestName());
+ auto inner_tuple = builder.Tuple(
+ {builder.ConstantR1<float>({1.0, 2.0}), builder.ConstantR0<float>(42.0)});
+ auto outer_tuple =
+ builder.Tuple({inner_tuple, builder.ConstantR1<float>({22.0, 44.0})});
+
+ auto expected_v1 = LiteralUtil::CreateR1<float>({1.0, 2.0});
+ auto expected_s = LiteralUtil::CreateR0<float>(42.0);
+ auto expected_inner_tuple =
+ LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()});
+ auto expected_v2 = LiteralUtil::CreateR1<float>({22.0, 44.0});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()});
+
+ ComputeAndCompareTuple(&builder, *expected, {}, error_spec_);
+}
+
+XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) {
+ ComputationBuilder builder(client_, TestName());
+
+ Shape data_shape = ShapeUtil::MakeShape(F32, {3});
+ Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({data_shape, data_shape});
+ Shape outer_tuple_shape =
+ ShapeUtil::MakeTupleShape({inner_tuple_shape, data_shape});
+
+ auto input = builder.Parameter(0, outer_tuple_shape, "input");
+ auto gte0 = builder.GetTupleElement(input, 0);
+ auto gte1 = builder.GetTupleElement(gte0, 1);
+ builder.Add(gte1, builder.ConstantR1<float>({10.0, 11.0, 12.0}));
+
+ std::unique_ptr<GlobalData> data =
+ client_
+ ->TransferToServer(*LiteralUtil::MakeTuple({
+ LiteralUtil::MakeTuple(
+ {
+ LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0}).get(),
+ LiteralUtil::CreateR1<float>({4.0, 5.0, 6.0}).get(),
+ })
+ .get(),
+ LiteralUtil::CreateR1<float>({7.0, 8.0, 9.0}).get(),
+ }))
+ .ConsumeValueOrDie();
+
+ std::vector<GlobalData*> arguments = {data.get()};
+ const std::vector<float> expected = {4.0 + 10.0, 5.0 + 11.0, 6.0 + 12.0};
+ ComputeAndCompareR1<float>(&builder, expected, arguments, ErrorSpec(1e-5));
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
new file mode 100644
index 0000000000..fdbaa0d178
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -0,0 +1,179 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class UnaryOpTest : public ClientLibraryTestBase {
+ protected:
+ template <typename T>
+ T inf() {
+ return std::numeric_limits<T>::infinity();
+ }
+ template <typename T>
+ void AbsSize0TestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<T>(&builder, {}, {});
+ }
+
+ template <typename T>
+ void AbsTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({-2, 25, 0, -123, inf<T>(), -inf<T>()});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<T>(&builder, {2, 25, 0, 123, inf<T>(), inf<T>()}, {});
+ }
+
+ template <typename T>
+ void SignTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>(
+ {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
+ auto sign = builder.Sign(arg);
+
+ ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
+ }
+
+ template <typename T>
+ void SignAbsTestHelper() {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<T>({-2, 25, 0, -123});
+ auto sign = builder.Sign(arg);
+ auto abs = builder.Abs(arg);
+ builder.Sub(builder.Mul(sign, abs), arg);
+
+ ComputeAndCompareR1<T>(&builder, {0, 0, 0, 0}, {});
+ }
+};
+
+template <>
+int UnaryOpTest::inf<int>() {
+ return 2147483647;
+}
+
+XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) {
+ AbsSize0TestHelper<int>();
+ AbsSize0TestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, AbsTestR1) {
+ AbsTestHelper<int>();
+ AbsTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, AbsTestR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto argi = builder.ConstantR0<int>(-5);
+ auto absi = builder.Abs(argi);
+ auto argf = builder.ConstantR0<float>(-3.0f);
+ auto absf = builder.Abs(argf);
+ auto argf0 = builder.ConstantR0<float>(-0.0f);
+ auto absf0 = builder.Abs(argf0);
+ builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
+ absi, PrimitiveType::F32)));
+
+ ComputeAndCompareR0<float>(&builder, 8.0f, {});
+}
+
+TEST_F(UnaryOpTest, SignTestR0) {
+ ComputationBuilder builder(client_, TestName());
+ auto argi = builder.ConstantR0<int>(-5);
+ auto absi = builder.Sign(argi);
+ auto argf = builder.ConstantR0<float>(-4.0f);
+ auto absf = builder.Sign(argf);
+ auto argf0 = builder.ConstantR0<float>(-0.0f);
+ auto absf0 = builder.Sign(argf0);
+ builder.Add(absf0, builder.Add(absf, builder.ConvertElementType(
+ absi, PrimitiveType::F32)));
+
+ ComputeAndCompareR0<float>(&builder, -2.0f, {});
+}
+
+TEST_F(UnaryOpTest, SignTestR1) {
+ SignTestHelper<int>();
+ SignTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, SignAbsTestR1) {
+ SignAbsTestHelper<int>();
+ SignAbsTestHelper<float>();
+}
+
+TEST_F(UnaryOpTest, UnsignedAbsTestR1) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<unsigned int>(
+ {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ auto abs = builder.Abs(arg);
+
+ ComputeAndCompareR1<unsigned int>(
+ &builder, {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()}, {});
+}
+
+TEST_F(UnaryOpTest, UnsignedSignTestR1) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR1<unsigned int>(
+ {2, 25, 0, 123, std::numeric_limits<unsigned int>::max()});
+ auto sign = builder.Sign(arg);
+
+ ComputeAndCompareR1<unsigned int>(&builder, {1, 1, 0, 1, 1}, {});
+}
+
+TEST_F(UnaryOpTest, SignAbsTestR2) {
+ ComputationBuilder builder(client_, TestName());
+ auto arg = builder.ConstantR2<float>({{1.0, -2.0}, {-3.0, 4.0}});
+ auto sign = builder.Sign(arg);
+ auto abs = builder.Abs(arg);
+ builder.Sub(builder.Mul(sign, abs), arg);
+
+ ComputeAndCompareR2<float>(&builder, {{0, 0}, {0, 0}}, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
new file mode 100644
index 0000000000..7f3d7d9cb4
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -0,0 +1,235 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <numeric>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+class VecOpsReduceTest : public ClientLibraryTestBase {
+ public:
+ VecOpsReduceTest() : builder_(client_, TestName()) {}
+
+ ComputationDataHandle BuildSampleConstantCube() {
+ // clang-format off
+ Array3D<float> x3d({
+ {{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0
+ {4.0, 5.0, 6.0}}, // V // }
+ // ---- dim 2 ---->
+ {{1.0, 2.0, 3.0}, // } plane 1 in dim 0
+ {4.0, 5.0, 6.0}},
+ {{1.0, 2.0, 3.0}, // } plane 2 in dim 0
+ {4.0, 5.0, 6.0}}});
+ // clang-format on
+ return builder_.ConstantR3FromArray3D<float>(x3d);
+ }
+
+ ComputationBuilder builder_;
+ ErrorSpec errspec_{1e-3, 0};
+};
+
+TEST_F(VecOpsReduceTest, AddReduceR1F32) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR0<float>(&builder_, -4.2f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceBigR1F32) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ std::vector<float> input(3000);
+ std::iota(input.begin(), input.end(), 100.0f);
+
+ auto x = builder_.ConstantR1<float>(input);
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ float expected = std::accumulate(input.begin(), input.end(), 0.0f);
+ ComputeAndCompareR0<float>(&builder_, expected, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, MaxReduceR1F32) {
+ auto max_reducer = CreateScalarMax();
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto max_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR0<float>(&builder_, 2.6f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, MaxReduceR1F32WithNontrivialInit) {
+ auto max_reducer = CreateScalarMax();
+
+ auto x = builder_.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto max_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(4.0f), max_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR0<float>(&builder_, 4.0f, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR2F32Dim1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ // clang-format off
+ auto x = builder_.ConstantR2<float>({
+ {1.0, 2.0, 3.0}, // | dim 0
+ {4.0, 5.0, 6.0}}); // |
+ // ------ dim 1 ----------
+ // clang-format on
+
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
+
+ ComputeAndCompareR1<float>(&builder_, {6.0, 15.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR2F32Dim0) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+
+ // clang-format off
+ auto x = builder_.ConstantR2<float>({
+ {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0}});
+ // clang-format on
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
+ ComputeAndCompareR1<float>(&builder_, {5.0, 7.0, 9.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{2});
+
+ Array2D<float> expected_array({{6.0f, 15.0f}, {6.0f, 15.0f}, {6.0f, 15.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1});
+
+ Array2D<float> expected_array(
+ {{5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}, {5.0f, 7.0f, 9.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dim0) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0});
+
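+ // The sample cube stacks three identical {{1, 2, 3}, {4, 5, 6}} planes along
+ // dim 0, so reducing over that dimension triples each element.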
+ Array2D<float> expected_array({{3.0f, 6.0f, 9.0f}, {12.0f, 15.0f, 18.0f}});
+
+ ComputeAndCompareR2<float>(&builder_, expected_array, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dims1and2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{1, 2});
+
+ ComputeAndCompareR1<float>(&builder_, {21.0, 21.0, 21.0}, {}, errspec_);
+}
+
+XLA_TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and2) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 2});
+
+ ComputeAndCompareR1<float>(&builder_, {18.0, 45.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32Dims0and1) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1});
+
+ ComputeAndCompareR1<float>(&builder_, {15.0, 21.0, 27.0}, {}, errspec_);
+}
+
+TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) {
+ auto sum_reducer = CreateScalarAddComputation(F32, &builder_);
+ auto x = BuildSampleConstantCube();
+ auto add_reduce =
+ builder_.Reduce(x, builder_.ConstantR0<float>(0.0f), sum_reducer,
+ /*dimensions_to_reduce=*/{0, 1, 2});
+
+ ComputeAndCompareR0<float>(&builder_, 63.0, {}, errspec_);
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
new file mode 100644
index 0000000000..d9fc1e1e8f
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -0,0 +1,423 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/array4d.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class VecOpsSimpleTest : public ClientLibraryTestBase {
+ public:
+ explicit VecOpsSimpleTest(perftools::gputools::Platform* platform = nullptr)
+ : ClientLibraryTestBase(platform,
+ /*disabled_pass_names=*/{"algsimp", "inline"}) {}
+
+ ErrorSpec error_spec_{0.0001};
+};
+
+TEST_F(VecOpsSimpleTest, ExpTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto exp = builder.Exp(x);
+
+ std::vector<float> expected = {8.1662, 7.4274e-02, 13.4637, 1.8316e-02,
+ 8.1662, 9.9742, 6.7379e-03, 4.0657e-01,
+ 9.0718e-02, 4.9530};
+
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, ExpManyValues) {
+ for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
+ ComputationBuilder builder(client_, TestName());
+ std::vector<float> exponents;
+ for (int i = 0; i < count; ++i) {
+ exponents.push_back(i / static_cast<float>(count));
+ }
+ auto x = builder.ConstantR1<float>(exponents);
+ auto exp = builder.Exp(x);
+
+ std::vector<float> expected;
+ for (float exponent : exponents) {
+ expected.push_back(std::exp(exponent));
+ }
+
+ ComputeAndCompareR1<float>(&builder, expected, {},
+ ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
+ }
+}
+
+TEST_F(VecOpsSimpleTest, ExpIn4D) {
+ ComputationBuilder builder(client_, TestName());
+ Array4D<float> exponents(2, 2, 2, 2);
+
+ std::vector<float> exponents_vector;
+ std::vector<float> expected_vector;
+ for (int i = 0; i < exponents.num_elements(); ++i) {
+ exponents_vector.push_back(static_cast<float>(i) /
+ exponents.num_elements());
+ expected_vector.push_back(std::exp(exponents_vector.back()));
+ }
+ exponents.SetValues(exponents_vector);
+
+ Array4D<float> expected(2, 2, 2, 2, expected_vector);
+
+ auto x = builder.ConstantR4FromArray4D<float>(exponents);
+ auto exp = builder.Exp(x);
+
+ ComputeAndCompareR4<float>(&builder, expected, {},
+ ErrorSpec(/*aabs=*/1e-2, /*arel=*/1e-3));
+}
+
+TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.Neg(x);
+
+ std::vector<float> expected = {-2.1, 2.6, -2.6, 4.0, -2.1,
+ -2.3, 5.0, 0.9, 2.4, -1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
+ builder.Neg(x);
+
+ std::vector<int> expected = {-2, 2, -12, 4, -5, -20, 15, 0, 2, -1};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, NegateUint32Values) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<uint32>(
+ {0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
+ builder.Neg(x);
+ std::vector<uint32> expected = {0, static_cast<uint32>(-1),
+ static_cast<uint32>(-42), 1, 12};
+ ComputeAndCompareR1<uint32>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, SquareTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.SquareF32(x);
+
+ std::vector<float> expected = {4.41, 6.76, 6.76, 16., 4.41,
+ 5.29, 25., 0.81, 5.76, 2.56};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ builder.ReciprocalF32(x);
+
+ std::vector<float> expected = {
+ 0.47619048, -0.38461538, 0.38461538, -0.25, 0.47619048,
+ 0.43478261, -0.2, -1.11111111, -0.41666667, 0.625};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
+ ComputationBuilder builder(client_, TestName());
+ auto add = CreateScalarAddComputation(F32, &builder);
+
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ auto max = builder.Map({x, y}, add);
+
+ std::vector<float> expected = {1.7, -3.2, -0.4, -3.8, 5.9,
+ 0.1, -6.8, 4., -1., 2.2};
+ ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ auto max = builder.Max(x, y);
+
+ std::vector<float> expected = {2.1, -0.6, 2.6, 0.2, 3.8,
+ 2.3, -1.8, 4.9, 1.4, 1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
+ // Similar to MaxTenValues, except that the inputs come from params rather
+ // than constants.
+ ComputationBuilder builder(client_, TestName());
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
+ {41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data = CreateR1Parameter<float>(
+ {21.0f, 22.0f, 23.0f, 24.0f}, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto max = builder.Max(v1, v2);
+ ComputeAndCompareR1<float>(&builder, {41.0f, 22.0f, 23.0f, 84.0f},
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
+ // Similar to MaxTenValuesFromParams, except that the data size passed in and
+ // out is large.
+ ComputationBuilder builder(client_, TestName());
+
+ // Number of floats in the data passed into and out of the computation.
+ constexpr int datalen = 15 * 1000;
+
+ // The inputs are initialized with a special pattern: v1[i] > v2[i] in the
+ // first third of the data, and v2[i] > v1[i] elsewhere.
+ std::vector<float> v1vec;
+ std::vector<float> v2vec;
+ std::vector<float> expected_vec;
+ for (int i = 0; i < datalen; ++i) {
+ float smaller = i;
+ float larger = i * 2;
+ if (i < datalen / 3) {
+ v1vec.push_back(larger);
+ v2vec.push_back(smaller);
+ } else {
+ v1vec.push_back(smaller);
+ v2vec.push_back(larger);
+ }
+ expected_vec.push_back(larger);
+ }
+
+ ComputationDataHandle v1, v2;
+ std::unique_ptr<GlobalData> param0_data =
+ CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
+ /*builder=*/&builder, /*data_handle=*/&v1);
+ std::unique_ptr<GlobalData> param1_data =
+ CreateR1Parameter<float>(v2vec, /*parameter_number=*/1, /*name=*/"v2",
+ /*builder=*/&builder, /*data_handle=*/&v2);
+
+ auto max = builder.Max(v1, v2);
+ ComputeAndCompareR1<float>(&builder, expected_vec,
+ {param0_data.get(), param1_data.get()},
+ error_spec_);
+}
+
+TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR0<float>(0);
+ auto max = builder.Max(x, y);
+
+ std::vector<float> expected = {2.1, 0.0, 2.6, 0.0, 2.1,
+ 2.3, 0.0, 0.0, 0.0, 1.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MinTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto y = builder.ConstantR1<float>(
+ {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6});
+ auto min = builder.Min(x, y);
+
+ std::vector<float> expected = {-0.4, -2.6, -3.0, -4.0, 2.1,
+ -2.2, -5.0, -0.9, -2.4, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0);
+ auto one = builder.ConstantR0<float>(1);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ auto clamp = builder.Min(builder.Max(x, zero), one);
+
+ std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
+ 0.9, 0.0, 0.1, 0.0, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR0<float>(0);
+ auto one = builder.ConstantR0<float>(1);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ auto clamp = builder.Clamp(zero, x, one);
+
+ std::vector<float> expected = {1.0, 0.0, 1.0, 0.3, 1.0,
+ 0.9, 0.0, 0.1, 0.0, 0.6};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
+ ComputationBuilder builder(client_, TestName());
+ auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
+ auto one = builder.ConstantR1<float>({1.0f, 1.0f});
+ auto x = builder.ConstantR1<float>({2.1, -2.6});
+ auto clamp = builder.Clamp(zero, x, one);
+
+ std::vector<float> expected = {1.0, 0.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
+ ComputationBuilder builder(client_, TestName());
+ auto one = builder.ConstantR0<float>(1);
+ auto two = builder.ConstantR0<float>(2);
+ auto x = builder.ConstantR1<float>(
+ {2.1, -2.6, 2.6, 0.3, 3.1, 0.9, -5.0, 0.1, -2.4, 0.6});
+ auto clamp = builder.Clamp(one, x, two);
+
+ std::vector<float> expected = {2.0, 1.0, 2.0, 1.0, 2.0,
+ 1.0, 1.0, 1.0, 1.0, 1.0};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+TEST_F(VecOpsSimpleTest, MapTenValues) {
+ Computation add_half;
+ {
+ // add_half(x) = x + 0.5
+ ComputationBuilder builder(client_, "add_half");
+ auto x_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
+ auto half = builder.ConstantR0<float>(0.5);
+ builder.Add(x_value, half);
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ add_half = computation_status.ConsumeValueOrDie();
+ }
+
+ Computation clamp;
+ {
+ // clamp(y) = clamp<0,5>(y)
+ ComputationBuilder builder(client_, "clamp");
+ auto y_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto clamped = builder.Clamp(zero, y_value, builder.ConstantR0<float>(5));
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ clamp = computation_status.ConsumeValueOrDie();
+ }
+
+ Computation mult_relu_add;
+ {
+ // mult_relu_add(z) = clamp(add_half(2 * max(z, 0)))
+ ComputationBuilder builder(client_, "mult_relu_add");
+ auto z_value =
+ builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
+ auto zero = builder.ConstantR0<float>(0.0);
+ auto two = builder.ConstantR0<float>(2.0);
+ auto max = builder.Max(z_value, zero);
+ auto mult = builder.Mul(two, max);
+ auto inner = builder.Map({mult}, add_half);
+ builder.Map({inner}, clamp);
+ auto computation_status = builder.Build();
+ ASSERT_IS_OK(computation_status.status());
+ mult_relu_add = computation_status.ConsumeValueOrDie();
+ }
+
+ ComputationBuilder builder(client_, "map10");
+ {
+ auto x = builder.ConstantR1<float>(
+ {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
+ auto activations = builder.Map({x}, mult_relu_add);
+ }
+
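+ // For example, 2.1 maps to clamp(2 * max(2.1, 0) + 0.5) = 4.7, 2.6 saturates
+ // at the clamp bound 5.0, and every negative input maps to 0.5.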
+ std::vector<float> expected = {4.7, 0.5, 5.0, 0.5, 4.7,
+ 5.0, 0.5, 0.5, 0.5, 3.7};
+ ComputeAndCompareR1<float>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<int32>({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4});
+ auto y = builder.ConstantR0<int32>(3);
+ builder.Rem(x, y);
+
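+ // Rem truncates toward zero, so the remainder takes the sign of the
+ // dividend, e.g. -5 rem 3 = -2 and -4 rem 3 = -1.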
+ std::vector<int32> expected = {-2, -1, 0, -2, -1, 0, 1, 2, 0, 1};
+ ComputeAndCompareR1<int32>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<bool>({false, true});
+ auto y = builder.ConstantR1<bool>({true, false});
+ builder.Eq(x, y);
+
+ std::array<bool, 2> expected = {{false, false}};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR1<bool>({false, true});
+ auto y = builder.ConstantR1<bool>({true, false});
+ builder.Ne(x, y);
+
+ std::array<bool, 2> expected = {{true, true}};
+ ComputeAndCompareR1<bool>(&builder, expected, {});
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
new file mode 100644
index 0000000000..7820bc363d
--- /dev/null
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -0,0 +1,395 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+namespace {
+
+class WhileTest : public ClientLibraryTestBase {};
+
+// Tests a while node when the result type T is S32.
+//
+// int32 result = 0;
+// while (result < 5) {
+// result = result + 1;
+// }
+TEST_F(WhileTest, WhileWithScalarResult) {
+ auto result_shape = ShapeUtil::MakeShape(S32, {});
+
+ // Create a computation for the condition: repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ builder.Gt(builder.ConstantR0<int32>(5), prev);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body: add 1 to the result variable.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR0<int32>(1);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR0<int32>(0);
+ auto result = builder.While(condition, body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+
+ ComputeAndCompareR0<int32>(&builder, 5, {});
+}
+
+// Tests a while node when the result type T is a vector.
+//
+// All constants are chosen to produce exact results.
+// vector<float> result(0);
+// while (result.sum() < 15.5f) {
+// result = result + vector<float>(0);
+// }
+// TODO(b/29185393): does not terminate on CPU.
+TEST_F(WhileTest, DISABLED_WhileWithEmptyVectorResult) {
+ Shape result_shape = ShapeUtil::MakeShape(F32, {0});
+
+ // Create a computation for the reduction.
+ Computation add;
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ add = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the condition.
+ // Repeat while the sum of the result vector is less than 15.5f.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add an empty constant vector to the result vector.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR1<float>({});
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.ConstantR1<float>({});
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ ComputeAndCompareR1<float>(&builder, {}, {}, ErrorSpec(0.0001));
+}
+
+// Tests a while node when the result type T is a vector.
+//
+// All constants are chosen to produce exact results.
+// vector<float> result(8, 0.0f);
+// while (result.sum() < 15.5f) {
+// result = result + vector<float>(8, 0.125f);
+// }
+TEST_F(WhileTest, WhileWithVectorResult) {
+ Shape result_shape = ShapeUtil::MakeShape(F32, {8});
+
+ // Create a computation for the reduction.
+ Computation add;
+ {
+ ComputationBuilder builder(client_, "add");
+ auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
+ auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "y");
+ builder.Add(x, y);
+ add = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the condition.
+ // Repeat while the sum of the result vector is less than 15.5f.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto sum = builder.Reduce(prev, builder.ConstantR0<float>(0.0f), add,
+ /*dimensions_to_reduce=*/{0});
+ auto test = builder.Gt(builder.ConstantR0<float>(15.5f), sum);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add a constant vector of 0.125f to the result vector.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto input = builder.ConstantR1<float>(8, 0.125f);
+ auto result = builder.Add(input, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.ConstantR1<float>(8, 0.f);
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ // Individual elements will increase by 1/8 each time through the loop, so
+ // the sum will increase by 1.0. It will first be >15.5 when the elements
+ // have all reached 2.0.
+ std::vector<float> expected = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f};
+ ComputeAndCompareR1<float>(&builder, expected, {}, ErrorSpec(0.0001));
+}
+
+// Tests a while node when the result type T is a Tuple.
+//
+// tuple<int32, vector<float>> result(0, vector<float>(10, 0.0f));
+// while (get<0>(result) < 5) {
+// get<0>(result) = get<0>(result) + 1;
+// get<1>(result) = get<1>(result) + vector<float>(10, 1.0f);
+// }
+TEST_F(WhileTest, WhileWithTupleResult) {
+ std::vector<Shape> shape_elements = {ShapeUtil::MakeShape(S32, {}),
+ ShapeUtil::MakeShape(F32, {10})};
+ Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements);
+
+ // Create a computation for the condition.
+ // Repeat for 5 iterations.
+ Computation condition;
+ {
+ ComputationBuilder builder(client_, "condition");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ builder.Gt(builder.ConstantR0<int32>(5), iteration);
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a computation for the body.
+ // Add 1 to the iteration variable and add a constant vector of 1.0f to
+ // the weight variable, both of which are tuple elements.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, result_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ auto weights = builder.GetTupleElement(prev, 1);
+ auto input = builder.ConstantR1<float>(10, 1.f);
+ auto new_weights = builder.Add(weights, input);
+ auto result = builder.Tuple(
+ {builder.Add(iteration, builder.ConstantR0<int32>(1)), new_weights});
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ ComputationBuilder builder(client_, "while");
+ auto init = builder.Tuple(
+ {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
+ auto result = builder.While(condition, body, init);
+ VLOG(2) << "while = " << ShapeUtil::HumanString(
+ *builder.GetShape(result).ConsumeValueOrDie());
+
+ auto expected_counter = LiteralUtil::CreateR0<int32>(5);
+ auto expected_data = LiteralUtil::CreateR1<float>(
+ {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f});
+ auto expected =
+ LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()});
+ VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape());
+ ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
+}
+
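+ // abs(-0.0f) + abs(-3.0f) + abs(-5) = 0.0f + 3.0f + 5.0f = 8.0f.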
+// Tests a while node when the result type T is a vector of S32.
+//
+// int32 result = (0, 0, 0, 0, 0, 0);
+// while (result[0] < count) {
+// result += (1, U[0, 100], U[0, 100], U[0, 100], U[0, 100], U[0, 100]);
+// }
+//
+// This test misuses a vector to represent a pair:
+//   (iteration, random vector).
+//
+// Note: this test currently only tests generating random values within a loop.
+// Per backend the values generated can be different as the different backends
+// use different random number generators.
+// TODO(b/32240857): Extend test to verify outputs.
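+ // sign(-0.0f) contributes 0.0f, sign(-4.0f) is -1.0f, and sign(-5) converts
+ // to -1.0f, so the sum is -2.0f.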
+TEST_F(WhileTest, WhileWithPrngScalarResult) {
+ auto v6s32 = ShapeUtil::MakeShape(S32, {6});
+
+ // Create a computation for the condition: repeat for count iterations.
+ auto build_condition = [this, v6s32](int count) {
+ ComputationBuilder builder(client_, TestName());
+ auto prev = builder.Reshape(
+ builder.Slice(builder.Parameter(0, v6s32, "prev"), {0}, {1}), {0}, {});
+ builder.Gt(builder.ConstantR0<int32>(count), prev);
+ return builder.Build().ConsumeValueOrDie();
+ };
+
+ // Create a computation for the body: add 1 to the counter element and
+ // random values to the remaining elements.
+ Computation body;
+ {
+ ComputationBuilder builder(client_, "body");
+ auto prev = builder.Parameter(0, v6s32, "prev");
+ auto inc = builder.ConcatInDim(
+ {builder.ConstantR1<int32>({1}),
+ builder.RngUniform(builder.ConstantR0<int32>(0),
+ builder.ConstantR0<int32>(100),
+ ShapeUtil::MakeShape(S32, {5}))},
+ 0);
+ auto result = builder.Add(inc, prev);
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While node with computations for the condition and the body.
+ auto while_loop = [this, &body, build_condition](int count) {
+ ComputationBuilder builder(client_, TestName());
+ auto init = builder.ConstantR1<int32>({0, 0, 0, 0, 0, 0});
+ auto result = builder.While(build_condition(count), body, init);
+ auto shape = builder.GetShape(result).ConsumeValueOrDie();
+ return builder.Build();
+ };
+
+ for (int i = 1; i < 4; ++i) {
+ TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i));
+ TF_ASSIGN_OR_ASSERT_OK(auto result,
+ client_->ExecuteAndTransfer(computation, {}, nullptr,
+ nullptr, /*seed=*/65));
+ }
+}
+
+void BM_WhileLoop(int num_iters) {
+ // Benchmark a simple kernel to measure while loop overheads.
+ tensorflow::testing::StopTiming();
+
+ se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie();
+ auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie();
+ StreamExecutorMemoryAllocator allocator(platform, executors);
+ LocalClient* client =
+ ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+
+ Shape loop_state_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {10})});
+
+ // Create while condition computation with 'loop_limit'.
+ const int32 loop_limit = 100;
+ Computation condition;
+ {
+ ComputationBuilder builder(client, "condition");
+ auto prev = builder.Parameter(0, loop_state_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ builder.Lt(iteration, builder.ConstantR0<int32>(loop_limit));
+ condition = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create while body computation with unit loop increment.
+ Computation body;
+ {
+ ComputationBuilder builder(client, "body");
+ auto prev = builder.Parameter(0, loop_state_shape, "prev");
+ auto iteration = builder.GetTupleElement(prev, 0);
+ auto weights = builder.GetTupleElement(prev, 1);
+ auto one = builder.ConstantR0<int32>(1);
+ auto next_iteration = builder.Add(iteration, one);
+ auto one_vec = builder.ConstantR1<float>(10, 1.f);
+ auto new_weights = builder.Add(weights, one_vec);
+ auto result = builder.Tuple({next_iteration, new_weights});
+ body = builder.Build().ConsumeValueOrDie();
+ }
+
+ // Create a While instruction.
+ ComputationBuilder builder(client, "while");
+ auto init = builder.Tuple(
+ {builder.ConstantR0<int32>(0), builder.ConstantR1<float>(10, 0.f)});
+ builder.While(condition, body, init);
+ auto computation = builder.Build().ConsumeValueOrDie();
+
+ // Run some warm-up executions.
+ LocalExecuteOptions options;
+ options.set_allocator(&allocator);
+ const int kWarmups = 2;
+ for (int i = 0; i < kWarmups; ++i) {
+ auto result = client->ExecuteLocally(computation, {}, options);
+ ASSERT_TRUE(result.ok());
+ }
+
+ // Run benchmark.
+ tensorflow::testing::StartTiming();
+ for (int i = 0; i < num_iters; ++i) {
+ auto result = client->ExecuteLocally(computation, {}, options);
+ ASSERT_TRUE(result.ok());
+ }
+}
+
+// TODO(b/32470510): Benchmark fails on parallel CPU backend.
+#ifndef XLA_TEST_BACKEND_CPU_PARALLEL
+BENCHMARK(BM_WhileLoop);
+#endif
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ std::vector<tensorflow::Flag> flag_list;
+ xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result) {
+ LOG(ERROR) << "\n" << usage;
+ return 2;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
+ return 2;
+ }
+ tensorflow::testing::RunBenchmarks();
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc
new file mode 100644
index 0000000000..7876272467
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_reader.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/text_literal_reader.h"
+
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadPath(
+ tensorflow::StringPiece path) {
+ CHECK(!path.ends_with(".gz"))
+ << "TextLiteralReader no longer supports reading .gz files";
+ std::unique_ptr<tensorflow::RandomAccessFile> file;
+ Status s =
+ tensorflow::Env::Default()->NewRandomAccessFile(path.ToString(), &file);
+ if (!s.ok()) {
+ return s;
+ }
+
+ TextLiteralReader reader(file.release());
+ return reader.ReadAllLines();
+}
+
+TextLiteralReader::TextLiteralReader(tensorflow::RandomAccessFile* file)
+ : file_(file) {}
+
+namespace {
+// This is an optimized version of tensorflow::str_util::Split which uses
+// StringPiece for the delimited strings and uses an out parameter for the
+// result to avoid vector creation/destruction.
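+//
+// For example, splitting "(0,0,1): 43.5" on ':' yields the pieces "(0,0,1)"
+// and " 43.5"; the caller trims the surrounding whitespace.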
+void SplitByDelimToStringPieces(tensorflow::StringPiece text, char delim,
+ std::vector<tensorflow::StringPiece>* result) {
+ result->clear();
+
+ if (text.empty()) {
+ return;
+ }
+
+ // The following loop is a little strange: its bound is text.size() + 1
+ // instead of the more typical text.size().
+ // The final iteration of the loop (when i is equal to text.size()) handles
+ // the trailing token.
+ size_t token_start = 0;
+ for (size_t i = 0; i < text.size() + 1; i++) {
+ if (i == text.size() || text[i] == delim) {
+ tensorflow::StringPiece token(text.data() + token_start, i - token_start);
+ result->push_back(token);
+ token_start = i + 1;
+ }
+ }
+}
+} // namespace
+
+StatusOr<std::unique_ptr<Literal>> TextLiteralReader::ReadAllLines() {
+ tensorflow::io::RandomAccessInputStream stream(file_.get());
+ tensorflow::io::BufferedInputStream buf(&stream, 65536);
+ string shape_string;
+ Status s = buf.ReadLine(&shape_string);
+ if (!s.ok()) {
+ return s;
+ }
+
+ tensorflow::StringPiece sp(shape_string);
+ if (tensorflow::str_util::RemoveWhitespaceContext(&sp) > 0) {
+ string tmp = sp.ToString();
+ shape_string = tmp;
+ }
+ TF_ASSIGN_OR_RETURN(Shape shape, ShapeUtil::ParseShapeString(shape_string));
+ if (shape.element_type() != F32) {
+ return Unimplemented(
+ "unsupported element type for text literal reading: %s",
+ ShapeUtil::HumanString(shape).c_str());
+ }
+
+ auto result = MakeUnique<Literal>();
+ const float fill = std::numeric_limits<float>::quiet_NaN();
+ LiteralUtil::PopulateWithValue<float>(fill, AsInt64Slice(shape.dimensions()),
+ result.get());
+ std::vector<tensorflow::StringPiece> pieces;
+ std::vector<tensorflow::StringPiece> coordinates;
+ std::vector<int64> coordinate_values;
+ string line;
+ while (buf.ReadLine(&line).ok()) {
+ SplitByDelimToStringPieces(line, ':', &pieces);
+ tensorflow::StringPiece coordinates_string = pieces[0];
+ tensorflow::StringPiece value_string = pieces[1];
+ tensorflow::str_util::RemoveWhitespaceContext(&coordinates_string);
+ tensorflow::str_util::RemoveWhitespaceContext(&value_string);
+ if (!coordinates_string.Consume("(")) {
+ return InvalidArgument(
+ "expected '(' at the beginning of coordinates: \"%s\"", line.c_str());
+ }
+ if (!tensorflow::str_util::ConsumeSuffix(&coordinates_string, ")")) {
+ return InvalidArgument("expected ')' at the end of coordinates: \"%s\"",
+ line.c_str());
+ }
+ float value;
+ if (!tensorflow::strings::safe_strtof(value_string.ToString().c_str(),
+ &value)) {
+ return InvalidArgument("could not parse value as float: \"%s\"",
+ value_string.ToString().c_str());
+ }
+ SplitByDelimToStringPieces(coordinates_string, ',', &coordinates);
+ coordinate_values.clear();
+ for (tensorflow::StringPiece piece : coordinates) {
+ int64 coordinate_value;
+ if (!tensorflow::strings::safe_strto64(piece, &coordinate_value)) {
+ return InvalidArgument(
+ "could not parse coordinate member as int64: \"%s\"",
+ piece.ToString().c_str());
+ }
+ coordinate_values.push_back(coordinate_value);
+ }
+ if (coordinate_values.size() != shape.dimensions_size()) {
+ return InvalidArgument(
+ "line did not have expected number of coordinates; want %d got %zu: "
+ "\"%s\"",
+ shape.dimensions_size(), coordinate_values.size(), line.c_str());
+ }
+ LiteralUtil::Set<float>(result.get(), coordinate_values, value);
+ }
+ return std::move(result);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/text_literal_reader.h b/tensorflow/compiler/xla/text_literal_reader.h
new file mode 100644
index 0000000000..3cfbb2c7fb
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_reader.h
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_
+#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Reads a textual literal from a file path. The format of the file must be:
+//
+// f32[1,2,3,4]
+// (0, 0, 0, 0): 1.234
+// (0, 0, 0, 1): 0xf00p-2
+// ...
+//
+// Note that for floating-point values, hexadecimal output (as in the second
+// value above) conveys the exact value more precisely.
+class TextLiteralReader {
+ public:
+ // See class comment -- reads a file in its entirety (there must be only one
+ // literal in the text file path provided).
+ static StatusOr<std::unique_ptr<Literal>> ReadPath(
+ tensorflow::StringPiece path);
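+ //
+ // Illustrative usage (the path here is hypothetical):
+ //   std::unique_ptr<Literal> literal =
+ //       TextLiteralReader::ReadPath("/tmp/literal.txt").ConsumeValueOrDie();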
+
+ private:
+ // Ownership of file is transferred.
+ explicit TextLiteralReader(tensorflow::RandomAccessFile* file);
+
+ // Parses a shape string on the first line, followed by lines of values to the
+ // end of the file.
+ StatusOr<std::unique_ptr<Literal>> ReadAllLines();
+
+ // Owns the file being read
+ std::unique_ptr<tensorflow::RandomAccessFile> file_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralReader);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_READER_H_
diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc
new file mode 100644
index 0000000000..94d0f2646b
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_reader_test.cc
@@ -0,0 +1,58 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/text_literal_reader.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(TextLiteralReaderTest, ReadsR3File) {
+ string contents = R"(f32[1,2,3]
+(0,0,0): 42.5
+(0,0,1): 43.5
+(0,0,2): 44.5
+(0,1,0): 45.5
+(0,1,1): 46.5
+(0,1,2): 47.5
+)";
+
+ string fname = tensorflow::testing::TmpDir() + "/ReadsR3File.data.txt";
+ EXPECT_TRUE(
+ tensorflow::WriteStringToFile(tensorflow::Env::Default(), fname, contents)
+ .ok());
+
+ std::unique_ptr<Literal> literal =
+ TextLiteralReader::ReadPath(fname).ConsumeValueOrDie();
+ EXPECT_TRUE(
+ ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape()));
+ EXPECT_EQ(42.5, LiteralUtil::Get<float>(*literal, {0, 0, 0}));
+ EXPECT_EQ(43.5, LiteralUtil::Get<float>(*literal, {0, 0, 1}));
+ EXPECT_EQ(44.5, LiteralUtil::Get<float>(*literal, {0, 0, 2}));
+ EXPECT_EQ(45.5, LiteralUtil::Get<float>(*literal, {0, 1, 0}));
+ EXPECT_EQ(46.5, LiteralUtil::Get<float>(*literal, {0, 1, 1}));
+ EXPECT_EQ(47.5, LiteralUtil::Get<float>(*literal, {0, 1, 2}));
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
new file mode 100644
index 0000000000..a5097e41cb
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -0,0 +1,64 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/text_literal_writer.h"
+
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+/* static */ tensorflow::Status TextLiteralWriter::WriteToPath(
+ const Literal& literal, tensorflow::StringPiece path) {
+ std::unique_ptr<tensorflow::WritableFile> f;
+ auto s = tensorflow::Env::Default()->NewWritableFile(path.ToString(), &f);
+ if (!s.ok()) {
+ return s;
+ }
+
+ s = f->Append(ShapeUtil::HumanString(literal.shape()) + "\n");
+ if (!s.ok()) {
+ return s;
+ }
+
+ tensorflow::Status status;
+ tensorflow::WritableFile* f_ptr = f.get();
+ LiteralUtil::EachCellAsString(
+ literal, [f_ptr, &status](tensorflow::gtl::ArraySlice<int64> indices,
+ const string& value) {
+ if (!status.ok()) {
+ return;
+ }
+ string coordinates = tensorflow::strings::StrCat(
+ "(", tensorflow::str_util::Join(indices, ", "), ")");
+
+ status = f_ptr->Append(
+ tensorflow::strings::StrCat(coordinates, ": ", value, "\n"));
+ });
+ auto ignored = f->Close();
+ return status;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/text_literal_writer.h b/tensorflow/compiler/xla/text_literal_writer.h
new file mode 100644
index 0000000000..545bd22da9
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_writer.h
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
+#define TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace xla {
+
+// Writes a literal to textual form at a file path.
+//
+// The format is roughly:
+//
+// f32[1,2,3,4]
+// (0, 0, 0, 0): 1.234
+// (0, 0, 0, 1): 0xf00p-2
+// ...
+//
+// This should be readable by xla::TextLiteralReader.
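+//
+// A minimal usage sketch (hypothetical path; the literal construction helper
+// comes from literal_util.h):
+//
+//   auto literal = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
+//   TF_CHECK_OK(TextLiteralWriter::WriteToPath(*literal, "/tmp/out.txt"));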
+class TextLiteralWriter {
+ public:
+ static tensorflow::Status WriteToPath(const Literal& literal,
+ tensorflow::StringPiece path);
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(TextLiteralWriter);
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TEXT_LITERAL_WRITER_H_
diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc
new file mode 100644
index 0000000000..9dce4d13bb
--- /dev/null
+++ b/tensorflow/compiler/xla/text_literal_writer_test.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/text_literal_writer.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+TEST(TextLiteralWriterTest, WritesFloatLiteral) {
+ auto literal = LiteralUtil::CreateR2<float>({
+ {3.14, 2.17}, {1.23, 4.56},
+ });
+ string path =
+ tensorflow::io::JoinPath(tensorflow::testing::TmpDir(), "/whatever");
+ ASSERT_IS_OK(TextLiteralWriter::WriteToPath(*literal, path));
+ string contents;
+ TF_CHECK_OK(tensorflow::ReadFileToString(tensorflow::Env::Default(), path,
+ &contents));
+ const string expected = R"(f32[2,2]
+(0, 0): 3.14
+(0, 1): 2.17
+(1, 0): 1.23
+(1, 1): 4.56
+)";
+ EXPECT_EQ(expected, contents);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
new file mode 100644
index 0000000000..46eab7f02b
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -0,0 +1,191 @@
+# Tools and utilities that aid in XLA development and usage.
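+#
+# A typical way to build one of the tools below (a sketch; any cc_binary target
+# in this file can be substituted):
+#   bazel build //tensorflow/compiler/xla/tools:show_literal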
+
+licenses(["notice"]) # Apache 2.0
+
+package(default_visibility = ["//tensorflow/compiler/xla:internal"])
+
+# Filegroup used to collect source files for dependency checking.
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+ visibility = ["//tensorflow/compiler/xla:internal"],
+)
+
+cc_binary(
+ name = "hex_floats_to_packed_literal",
+ srcs = ["hex_floats_to_packed_literal.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "dumped_computation_to_graphviz_library",
+ srcs = ["dumped_computation_to_graphviz.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/legacy_flags:service_flags",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_binary(
+ name = "dumped_computation_to_graphviz",
+ deps = [
+ ":dumped_computation_to_graphviz_library",
+ ],
+)
+
+cc_binary(
+ name = "show_signature",
+ srcs = ["show_signature.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "replay_computation_library",
+ srcs = ["replay_computation.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/lib:testing",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = True,
+)
+
+cc_binary(
+ name = "replay_computation_cpu",
+ deps = [
+ ":replay_computation_library",
+ "//tensorflow/compiler/xla/service:cpu_plugin",
+ ],
+)
+
+cc_binary(
+ name = "replay_computation_gpu",
+ deps = [
+ ":replay_computation_library",
+ "//tensorflow/compiler/xla/service:gpu_plugin",
+ ],
+)
+
+cc_binary(
+ name = "show_literal",
+ srcs = ["show_literal.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_binary(
+ name = "convert_computation",
+ srcs = ["convert_computation.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_binary(
+ name = "show_text_literal",
+ srcs = ["show_text_literal.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:text_literal_reader",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_binary(
+ name = "dumped_computation_to_text",
+ srcs = ["dumped_computation_to_text.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_binary(
+ name = "dumped_computation_to_operation_list",
+ srcs = ["dumped_computation_to_operation_list.cc"],
+ deps = [
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client",
+ "//tensorflow/compiler/xla/client:client_library",
+ "//tensorflow/compiler/xla/client:computation",
+ "//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/service",
+ "//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/core:lib",
+ ],
+)
+
+# -----------------------------------------------------------------------------
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/compiler/xla/tools/convert_computation.cc b/tensorflow/compiler/xla/tools/convert_computation.cc
new file mode 100644
index 0000000000..fe03a6e7bd
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/convert_computation.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: convert_computation <txt2bin|bin2txt> serialized_computation_proto
+//
+// bin2txt spits out the result to stdout. txt2bin modifies the file in place.
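+//
+// Example invocations (hypothetical paths):
+//   convert_computation bin2txt /tmp/computation.pb
+//   convert_computation txt2bin /tmp/computation.pbtxt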
+
+#include <stdio.h>
+#include <unistd.h>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace xla {
+namespace tools {
+
+void RealMain(const string& mode, const string& path) {
+ SessionModule module;
+ tensorflow::Env* env = tensorflow::Env::Default();
+ if (mode == "txt2bin") {
+ TF_CHECK_OK(tensorflow::ReadTextProto(env, path, &module));
+ TF_CHECK_OK(tensorflow::WriteBinaryProto(env, path, module));
+ } else if (mode == "bin2txt") {
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(env, path, &module));
+ string out;
+ tensorflow::protobuf::TextFormat::PrintToString(module, &out);
+ fprintf(stdout, "%s", out.c_str());
+ } else {
+ LOG(QFATAL) << "unknown mode for computation conversion: " << mode;
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ QCHECK_EQ(argc, 3) << "usage: " << argv[0] << " <txt2bin|bin2txt> <path>";
+ xla::tools::RealMain(argv[1], argv[2]);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
new file mode 100644
index 0000000000..10efa9f3e8
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_graphviz.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: dumped_computation_to_graphviz some_binary_snapshot_proto*
+//
+// Dumps a graphviz URL for a snapshot computation to the command line.
+//
+// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// ServiceInterface::SnapshotComputation to disk.
+//
+// The GraphViz URL is logged to stderr, whereas computation statistics are
+// printed on stdout (implementation note: requesting computation statistics is
+// what triggers the compilation that emits the GraphViz URL).
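+//
+// Example invocation (hypothetical snapshot path):
+//   dumped_computation_to_graphviz /tmp/computation_snapshot.pb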
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/legacy_flags/service_flags.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace tools {
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+ Client* client = ClientLibrary::LocalClientOrDie();
+ for (char* arg : args) {
+ SessionModule module;
+ TF_CHECK_OK(
+ tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
+ Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ ComputationStats stats =
+ client->GetComputationStats(computation).ConsumeValueOrDie();
+ fprintf(stdout, ">>> %s :: %s\n", arg, stats.DebugString().c_str());
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ xla::legacy_flags::ServiceFlags* flags = xla::legacy_flags::GetServiceFlags();
+ flags->xla_generate_hlo_graph = ".*";
+ flags->xla_hlo_graph_layout = true;
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
new file mode 100644
index 0000000000..4c242abc9b
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -0,0 +1,111 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Dumps out the operations that are present in a serialized computation.
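+//
+// Example invocation (hypothetical snapshot path):
+//   dumped_computation_to_operation_list /tmp/computation_snapshot.pb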
+
+#include <iostream>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace tools {
+
+class OperationDumper : public DfsHloVisitorWithDefault {
+ public:
+ explicit OperationDumper(const string& path) : path_(path) {}
+
+ Status DefaultAction(HloInstruction* hlo) override {
+ string params = tensorflow::str_util::Join(
+ hlo->operands(), ", ", [](string* out, const HloInstruction* operand) {
+ tensorflow::strings::StrAppend(
+ out, ShapeUtil::HumanString(operand->shape()));
+ });
+    // Prints `op_name :: (param_types) -> result_type :: path` to stdout.
+ std::cout << tensorflow::strings::Printf(
+ "%s :: (%s) -> %s :: %s\n", HloOpcodeString(hlo->opcode()).c_str(),
+ params.c_str(), ShapeUtil::HumanString(hlo->shape()).c_str(),
+ path_.c_str());
+ return Status::OK();
+ }
+
+ private:
+ string path_;
+};
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalService* local_service =
+ ClientLibrary::GetXlaService(client->platform());
+ for (char* arg : args) {
+ SessionModule session_module;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
+ &session_module));
+ auto computation_status = client->LoadSnapshot(session_module);
+ if (!computation_status.ok()) {
+ fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
+ computation_status.status().ToString().c_str());
+ continue;
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ std::unique_ptr<ProgramShape> program_shape =
+ client->GetComputationShape(computation).ConsumeValueOrDie();
+
+ std::vector<const Shape*> layouts;
+ for (int i = 0; i < program_shape->parameters_size(); ++i) {
+ layouts.push_back(&program_shape->parameters(i));
+ }
+ StatusOr<std::unique_ptr<Executable>> executable =
+ local_service->CompileExecutable(
+ computation.handle(), layouts, &program_shape->result(),
+ /*device_ordinal=*/0, /*has_hybrid_result=*/true);
+
+ const HloModule& module = executable.ValueOrDie()->module();
+
+ OperationDumper dumper(arg);
+ for (auto& computation : module.computations()) {
+ TF_CHECK_OK(computation->root_instruction()->Accept(&dumper));
+ }
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
new file mode 100644
index 0000000000..8b96e13489
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
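+// Usage sketch: prints the HLO text for each serialized SessionModule snapshot
+// passed on the command line (hypothetical path):
+//
+//   dumped_computation_to_text /tmp/computation_snapshot.pb
+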
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/service.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace tools {
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+ LocalClient* client = ClientLibrary::LocalClientOrDie();
+ LocalService* local_service =
+ ClientLibrary::GetXlaService(client->platform());
+ for (char* arg : args) {
+ SessionModule session_module;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg,
+ &session_module));
+ auto computation_status = client->LoadSnapshot(session_module);
+ if (!computation_status.ok()) {
+ fprintf(stderr, "could not load snapshot for %s: %s\n", arg,
+ computation_status.status().ToString().c_str());
+ continue;
+ }
+ Computation computation = computation_status.ConsumeValueOrDie();
+
+ std::unique_ptr<ProgramShape> program_shape =
+ client->GetComputationShape(computation).ConsumeValueOrDie();
+
+ std::vector<const Shape*> layouts;
+ for (int i = 0; i < program_shape->parameters_size(); ++i) {
+ layouts.push_back(&program_shape->parameters(i));
+ }
+ StatusOr<std::unique_ptr<Executable>> executable =
+ local_service->CompileExecutable(
+ computation.handle(), layouts, &program_shape->result(),
+ /*device_ordinal=*/0, /*has_hybrid_result=*/true);
+
+ const HloModule& module = executable.ValueOrDie()->module();
+
+ fprintf(stdout, "HLO for %s backend:\n%s\n",
+ local_service->backend().platform()->Name().c_str(),
+ module.ToString().c_str());
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
new file mode 100644
index 0000000000..eb7bff053b
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hex_floats_to_packed_literal.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
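+// Converts a text file with one floating-point value per line into a file of
+// packed raw float bytes (a usage sketch; the flag names match the definitions
+// below and the paths are hypothetical):
+//
+//   hex_floats_to_packed_literal \
+//     --input_file=/tmp/floats.txt --output_file=/tmp/floats.bin
+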
+#include <stdio.h>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/casts.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+using xla::string;
+
+int main(int argc, char** argv) {
+ // Flags
+ string input_file = "";
+ string output_file = "";
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("input_file", &input_file, "file to convert"),
+ tensorflow::Flag("output_file", &output_file, "converted file"),
+ };
+ string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc != 1 || !parse_ok) {
+ LOG(QFATAL) << usage;
+ }
+
+ if (input_file.empty()) {
+ LOG(QFATAL) << "--input_file is required";
+ }
+ if (output_file.empty()) {
+ LOG(QFATAL) << "--output_file is required";
+ }
+
+ std::unique_ptr<tensorflow::RandomAccessFile> file;
+ TF_CHECK_OK(
+ tensorflow::Env::Default()->NewRandomAccessFile(input_file, &file));
+
+ std::vector<float> floats;
+ string line;
+ tensorflow::io::RandomAccessInputStream stream(file.get());
+ tensorflow::io::BufferedInputStream buf(&stream, 1048576);
+ while (buf.ReadLine(&line).ok()) {
+ float value;
+    QCHECK(sscanf(line.c_str(), "%f", &value) == 1) << "invalid float value: "
+ << line;
+ floats.push_back(value);
+ }
+
+ tensorflow::StringPiece content(
+ tensorflow::bit_cast<const char*>(floats.data()),
+ floats.size() * sizeof(float));
+ TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(),
+ output_file, content));
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
new file mode 100644
index 0000000000..ffb2d5aefb
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -0,0 +1,129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: replay_computation some_binary_snapshot_proto*
+//
+// Replays computations and shows the results on the command line.
+//
+// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// ServiceInterface::SnapshotComputation to disk.
+//
+// Computations that require arguments can be replayed using fake data by
+// passing --use_fake_data on the command line. If the real data is available
+// in the proto and --use_fake_data is false, the real data is used.
+//
+// The output format is:
+//
+// file_path: computation_name :: type:literal_str
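+//
+// Example invocation (hypothetical snapshot path; replay_computation_cpu and
+// replay_computation_gpu are the binaries defined in tools/BUILD):
+//   replay_computation_cpu --use_fake_data=true /tmp/computation_snapshot.pb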
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/global_data.h"
+#include "tensorflow/compiler/xla/client/lib/testing.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace xla {
+namespace tools {
+
+// Invokes the given computation passing arbitrary data for every (unbound)
+// parameter if use_fake_data is set; otherwise, uses the recorded data if it
+// is available.
+StatusOr<std::unique_ptr<Literal>> ReplayComputation(
+ const SessionModule& module, bool use_fake_data, Client* client) {
+ TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module));
+
+ std::vector<std::unique_ptr<GlobalData>> arguments;
+ if (use_fake_data) {
+ arguments = MakeFakeArgumentsOrDie(computation, client);
+ } else { // use recorded data if available
+ for (const Literal& literal : module.arguments()) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<GlobalData> data,
+ client->TransferToServer(literal));
+ arguments.push_back(std::move(data));
+ }
+ }
+
+ std::vector<GlobalData*> execute_arguments;
+ for (auto& argument : arguments) {
+ execute_arguments.push_back(argument.get());
+ }
+ return client->ExecuteAndTransfer(computation, execute_arguments);
+}
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args, bool use_fake_data) {
+ Client* client = ClientLibrary::LocalClientOrDie();
+ tensorflow::Env* env = tensorflow::Env::Default();
+ for (char* arg : args) {
+ SessionModule module;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module));
+ StatusOr<std::unique_ptr<Literal>> result_status =
+ ReplayComputation(module, use_fake_data, client);
+ if (!result_status.ok()) {
+ fprintf(stderr, "%s: error: %s\n", arg,
+ result_status.status().ToString().c_str());
+ continue;
+ }
+ std::unique_ptr<Literal> result = result_status.ConsumeValueOrDie();
+ fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(),
+ ShapeUtil::HumanString(result->shape()).c_str(),
+ LiteralUtil::ToString(*result).c_str());
+ if (module.has_result()) {
+ fprintf(stdout, "was %s:%s\n",
+ ShapeUtil::HumanString(module.result().shape()).c_str(),
+ LiteralUtil::ToString(module.result()).c_str());
+ }
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ // Flags
+ bool use_fake_data = false;
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("use_fake_data", &use_fake_data,
+ "Replay computation using fake data"),
+ };
+ xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc < 2 || !parse_ok) {
+ LOG(QFATAL) << usage;
+ }
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args, use_fake_data);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc
new file mode 100644
index 0000000000..cf363913b1
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/show_literal.cc
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: show_literal <path-to-serialized-literal-proto>
+//
+// Dumps out the Literal::ToString of a tensorflow::WriteBinaryProto format
+// Literal serialized on disk.
+
+#include <stdio.h>
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+int main(int argc, char **argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ if (argc < 2) {
+ LOG(QFATAL) << "Usage: " << argv[0]
+ << " <path-to-serialized-literal-proto>";
+ }
+
+ xla::Literal literal;
+ TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1],
+ &literal));
+ LOG(INFO) << "literal: " << literal.ShortDebugString();
+ fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str());
+}
diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc
new file mode 100644
index 0000000000..1f3340cbc6
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/show_signature.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: show_signature some_binary_snapshot_proto*
+//
+// Shows the signature (ProgramShape) of binary snapshot proto(s) on the command
+// line.
+//
+// some_binary_snapshot_proto is obtained by serializing the SessionModule from
+// ServiceInterface::SnapshotComputation to disk.
+//
+// The output format is:
+//
+// file_path: computation_name :: program_shape_str
+
+#include <stdio.h>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/computation.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+namespace tools {
+
+void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
+ Client* client = ClientLibrary::LocalClientOrDie();
+ for (char* arg : args) {
+ SessionModule module;
+ TF_CHECK_OK(
+ tensorflow::ReadBinaryProto(tensorflow::Env::Default(), arg, &module));
+ Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
+ std::unique_ptr<ProgramShape> shape =
+ client->GetComputationShape(computation).ConsumeValueOrDie();
+ fprintf(stdout, "%s: %s :: %s\n", arg, module.entry().name().c_str(),
+ ShapeUtil::HumanString(*shape).c_str());
+ }
+}
+
+} // namespace tools
+} // namespace xla
+
+int main(int argc, char** argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ tensorflow::gtl::ArraySlice<char*> args(argv, argc);
+ args.pop_front(); // Pop off the binary name, argv[0]
+ xla::tools::RealMain(args);
+ return 0;
+}
diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc
new file mode 100644
index 0000000000..2d983b407c
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/show_text_literal.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Usage: show_text_literal <path-to-serialized-literal-text>
+
+#include <stdio.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/text_literal_reader.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+int main(int argc, char **argv) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ if (argc < 2) {
+ LOG(QFATAL) << "Usage: " << argv[0] << " <path-to-serialized-literal-text>";
+ }
+
+ std::unique_ptr<xla::Literal> literal =
+ xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie();
+
+ LOG(INFO) << "literal: " << literal->ShortDebugString();
+ fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(*literal).c_str());
+ if (literal->shape().element_type() == xla::F32) {
+ float min =
+ *std::min_element(literal->f32s().begin(), literal->f32s().end());
+ float max =
+ *std::max_element(literal->f32s().begin(), literal->f32s().end());
+ fprintf(stderr, "min: %a=%f\n", min, min);
+ fprintf(stderr, "max: %a=%f\n", max, max);
+ }
+}
diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h
new file mode 100644
index 0000000000..8258031a2c
--- /dev/null
+++ b/tensorflow/compiler/xla/types.h
@@ -0,0 +1,37 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_
+#define TENSORFLOW_COMPILER_XLA_TYPES_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+using ::tensorflow::string;
+
+using ::tensorflow::int8;
+using ::tensorflow::int16;
+using ::tensorflow::int32;
+using ::tensorflow::int64;
+
+using ::tensorflow::uint8;
+using ::tensorflow::uint16;
+using ::tensorflow::uint32;
+using ::tensorflow::uint64;
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_TYPES_H_
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
new file mode 100644
index 0000000000..d23002c1a0
--- /dev/null
+++ b/tensorflow/compiler/xla/util.cc
@@ -0,0 +1,238 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/util.h"
+
+#include <stdarg.h>
+
+#include "tensorflow/compiler/xla/legacy_flags/util_flags.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/stacktrace.h"
+
+namespace xla {
+namespace {
+
+// Adds a backtrace to the provided status iff the xla_status_add_backtrace flag
+// is set. This is useful for quickly tracing status errors observed coming out
+// of the service.
+Status MaybeAddBacktrace(Status prior) {
+ DCHECK(!prior.ok());
+ if (legacy_flags::GetUtilFlags()->xla_status_add_backtrace) {
+ return Status{prior.code(),
+ tensorflow::strings::StrCat(prior.error_message(), " :: ",
+ tensorflow::CurrentStackTrace())};
+ } else {
+ return prior;
+ }
+}
+
+} // namespace
+
+ScopedLoggingTimer::ScopedLoggingTimer(const string& label, int32 vlog_level)
+ : label(label), vlog_level(vlog_level) {
+ if (VLOG_IS_ON(vlog_level)) {
+ start_micros = tensorflow::Env::Default()->NowMicros();
+ }
+}
+
+ScopedLoggingTimer::~ScopedLoggingTimer() {
+ if (VLOG_IS_ON(vlog_level)) {
+ uint64 end_micros = tensorflow::Env::Default()->NowMicros();
+ double secs = (end_micros - start_micros) / 1000000.0;
+
+ LOG(INFO) << label << " time: "
+ << tensorflow::strings::HumanReadableElapsedTime(secs);
+ }
+}
+
+Status AddStatus(Status prior, tensorflow::StringPiece context) {
+ CHECK(!prior.ok());
+ return Status{prior.code(), tensorflow::strings::StrCat(
+ context, ": ", prior.error_message())};
+}
+
+Status AppendStatus(Status prior, tensorflow::StringPiece context) {
+ CHECK(!prior.ok());
+ return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
+ ": ", context)};
+}
+
+// Implementation note: we can't factor these out (without using macros)
+// because they each need to va_start/va_end their varargs in their own frame.
+
+Status InvalidArgument(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::InvalidArgument(message));
+}
+
+Status Unimplemented(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::Unimplemented(message));
+}
+
+Status InternalError(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::Internal(message));
+}
+
+Status FailedPrecondition(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::FailedPrecondition(message));
+}
+
+Status ResourceExhausted(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::ResourceExhausted(message));
+}
+
+Status NotFound(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::NotFound(message));
+}
+
+Status Unavailable(const char* format, ...) {
+ string message;
+ va_list args;
+ va_start(args, format);
+ tensorflow::strings::Appendv(&message, format, args);
+ va_end(args);
+ return MaybeAddBacktrace(tensorflow::errors::Unavailable(message));
+}
+
+string Reindent(tensorflow::StringPiece original,
+ const tensorflow::StringPiece indentation) {
+ std::vector<string> pieces = tensorflow::str_util::Split(
+ tensorflow::StringPiece(original.data(), original.size()), '\n');
+ return tensorflow::str_util::Join(
+ pieces, "\n", [indentation](string* out, string s) {
+ tensorflow::StringPiece piece(s);
+ tensorflow::str_util::RemoveWhitespaceContext(&piece);
+ tensorflow::strings::StrAppend(out, indentation, piece);
+ });
+}
+
+std::vector<int64> InversePermutation(
+ tensorflow::gtl::ArraySlice<int64> input_permutation) {
+ std::vector<int64> output_permutation(input_permutation.size(), -1);
+ for (size_t i = 0; i < input_permutation.size(); ++i) {
+ output_permutation[input_permutation[i]] = i;
+ }
+ DCHECK_EQ(
+ 0, std::count(output_permutation.begin(), output_permutation.end(), -1));
+ DCHECK(std::is_permutation(input_permutation.begin(), input_permutation.end(),
+ output_permutation.begin()));
+ return output_permutation;
+}
+
+std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
+ tensorflow::gtl::ArraySlice<int64> p2) {
+ CHECK_EQ(p1.size(), p2.size());
+ std::vector<int64> output;
+ for (size_t i = 0; i < p1.size(); ++i) {
+ output.push_back(p1[p2[i]]);
+ }
+ return output;
+}
+
+int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
+ int64 value) {
+ return std::find(container.begin(), container.end(), value) -
+ container.begin();
+}
+
+PaddingConfig MakeNoPaddingConfig(int64 rank) {
+ PaddingConfig padding_config;
+ for (int64 dnum = 0; dnum < rank; ++dnum) {
+ auto dimension = padding_config.add_dimensions();
+ dimension->set_edge_padding_low(0);
+ dimension->set_edge_padding_high(0);
+ dimension->set_interior_padding(0);
+ }
+ return padding_config;
+}
+
+string HumanReadableNumFlops(double flops, double nanoseconds) {
+ if (nanoseconds == 0) {
+ return "NaN FLOP/s";
+ }
+ double nano_flops = flops / nanoseconds;
+ string throughput = tensorflow::strings::HumanReadableNum(
+ static_cast<int64>(nano_flops * 1e9));
+ tensorflow::StringPiece sp(throughput);
+ // Use the more common "G(FLOPS)", rather than "B(FLOPS)"
+ if (sp.ends_with("B") || // Ends in 'B', ignoring case
+ sp.ends_with("b")) {
+ *throughput.rbegin() = 'G';
+ }
+ throughput += "FLOP/s";
+ return throughput;
+}
+
+void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
+ int lineno) {
+ const int orig_sev = sev;
+ if (sev == tensorflow::FATAL) {
+ sev = tensorflow::ERROR;
+ }
+
+ size_t cur = 0;
+ while (cur < text.size()) {
+ size_t eol = text.find('\n', cur);
+ if (eol == tensorflow::StringPiece::npos) {
+ eol = text.size();
+ }
+ auto msg = text.substr(cur, eol - cur);
+ tensorflow::internal::LogString(fname, lineno, sev,
+ string(msg.data(), msg.size()));
+ cur = eol + 1;
+ }
+
+ if (orig_sev == tensorflow::FATAL) {
+ tensorflow::internal::LogString(fname, lineno, orig_sev,
+ "Aborting due to errors.");
+ }
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h
new file mode 100644
index 0000000000..137c613e6f
--- /dev/null
+++ b/tensorflow/compiler/xla/util.h
@@ -0,0 +1,257 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generally useful utility functions that are common to (not specific to any
+// given part of) the XLA code base.
+
+#ifndef TENSORFLOW_COMPILER_XLA_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_UTIL_H_
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+
+// RAII timer that logs with a given label the wall clock time duration in human
+// readable form. This differs from base's ElapsedTimer primarily in that it
+// spits out the human-readable duration form.
+struct ScopedLoggingTimer {
+ explicit ScopedLoggingTimer(const string& label, int32 vlog_level = 1);
+ ~ScopedLoggingTimer();
+
+ uint64 start_micros;
+ string label;
+ int32 vlog_level;
+};
+
+// Given a vector<T>, returns a MutableArraySlice<uint8> that points at its
+// internals.
+//
+// Warning: if the vector is updated its storage pointer may change, so use this
+// with caution (ideally in limited scopes with temporary lifetimes).
+template <typename T>
+tensorflow::gtl::MutableArraySlice<uint8> MutableByteSlice(std::vector<T>* v) {
+ return tensorflow::gtl::MutableArraySlice<uint8>(
+ reinterpret_cast<uint8*>(v->data()), v->size() * sizeof(T));
+}
+
+// Turns an immutable slice of type T into an immutable slice of bytes with the
+// same byte size.
+template <typename T>
+tensorflow::gtl::ArraySlice<uint8> CastToByteSlice(
+ tensorflow::gtl::ArraySlice<T> slice) {
+ return tensorflow::gtl::ArraySlice<uint8>(
+ reinterpret_cast<const uint8*>(slice.data()), slice.size() * sizeof(T));
+}
+
+// Casts a byte slice to a non-byte type T, checking that the original slice
+// length is a multiple of sizeof(T).
+template <typename T>
+tensorflow::gtl::ArraySlice<T> CastByteSlice(
+ tensorflow::gtl::ArraySlice<uint8> slice) {
+ CHECK_EQ(0, slice.size() % sizeof(T));
+ return tensorflow::gtl::ArraySlice<T>(
+ reinterpret_cast<const T*>(slice.data()), slice.size() / sizeof(T));
+}
+
+// Convenience function to force a vector to convert to an immutable slice.
+template <typename T>
+tensorflow::gtl::ArraySlice<T> AsSlice(const std::vector<T>& v) {
+ return tensorflow::gtl::ArraySlice<T>(v);
+}
+
+// Converts a mutable vector pointer into a MutableArraySlice of the same
+// type.
+template <typename T>
+tensorflow::gtl::MutableArraySlice<T> AsMutableSlice(std::vector<T>* v) {
+ return tensorflow::gtl::MutableArraySlice<T>(v->data(), v->size());
+}
+
+// xla::int64 is not the same type as tensorflow::protobuf_int64 in open-source.
+// Wrapper function that gives an int64 array slice view of a repeated int64
+// protobuf field.
+static inline tensorflow::gtl::ArraySlice<int64> AsInt64Slice(
+ const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& v) {
+ tensorflow::gtl::ArraySlice<tensorflow::protobuf_int64> slice(v);
+ return tensorflow::gtl::ArraySlice<int64>(
+ reinterpret_cast<const int64*>(slice.data()), slice.size());
+}
+
+// As above, but for uint64 types.
+static inline tensorflow::gtl::ArraySlice<uint64> AsUInt64Slice(
+ const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
+ tensorflow::gtl::ArraySlice<tensorflow::protobuf_uint64> slice(v);
+ return tensorflow::gtl::ArraySlice<uint64>(
+ reinterpret_cast<const uint64*>(slice.data()), slice.size());
+}
+
+// Compares two containers for equality. Returns true iff the two containers
+// have the same size and all their elements compare equal using their
+// operator==. Like std::equal, but forces size equality.
+template <typename Container1T, typename Container2T>
+bool ContainersEqual(const Container1T& c1, const Container2T& c2) {
+ return ((c1.size() == c2.size()) &&
+ std::equal(std::begin(c1), std::end(c1), std::begin(c2)));
+}
+
+// Compares two containers for equality. Returns true iff the two containers
+// have the same size and all their elements compare equal using the predicate
+// p. Like std::equal, but forces size equality.
+template <typename Container1T, typename Container2T, class PredicateT>
+bool ContainersEqual(const Container1T& c1, const Container2T& c2,
+ PredicateT p) {
+ return ((c1.size() == c2.size()) &&
+ std::equal(std::begin(c1), std::end(c1), std::begin(c2), p));
+}
+
+// Adds some context information to the error message in a Status. This is
+// useful as Statuses are propagated upwards.
+Status AddStatus(Status prior, tensorflow::StringPiece context);
+Status AppendStatus(Status prior, tensorflow::StringPiece context);
+
+// Status error shorthands -- printfs the arguments to be used as an error
+// message and returns a status in the canonical error space.
+Status InvalidArgument(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status Unimplemented(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status InternalError(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status FailedPrecondition(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2);
+
+// Splits the lines of the original, replaces leading whitespace with the prefix
+// given by "indentation", and returns the string joined by newlines again. As a
+// side effect, any additional trailing whitespace is removed.
+//
+// Note: even different amounts of leading whitespace on different lines will be
+// uniformly replaced with "indentation".
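+//
+// Illustrative example: Reindent("  foo\n      bar", "  ") returns
+// "  foo\n  bar".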
+string Reindent(tensorflow::StringPiece original,
+ tensorflow::StringPiece indentation);
+
+// Applies `permutation` on `input` and returns the permuted array.
+// For each i, output[permutation[i]] = input[i].
+//
+// Precondition:
+// 1. `permutation` is a permutation of 0..permutation.size()-1.
+// 2. permutation.size() == input.size().
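+//
+// Illustrative example: Permute({2, 0, 1}, std::vector<int>{10, 20, 30})
+// returns {20, 30, 10}, since output[2] = 10, output[0] = 20 and
+// output[1] = 30.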
+template <template <typename...> class C, typename T>
+std::vector<T> Permute(tensorflow::gtl::ArraySlice<int64> permutation,
+ C<T> input_) {
+ tensorflow::gtl::ArraySlice<T> input(input_);
+ CHECK_EQ(permutation.size(), input.size());
+ std::vector<T> output(input.size());
+ for (size_t i = 0; i < permutation.size(); ++i) {
+ output[permutation[i]] = input[i];
+ }
+ DCHECK(std::is_permutation(input.begin(), input.end(), output.begin()));
+ return output;
+}
+
+// Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i.
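+// Illustrative example: InversePermutation({2, 0, 1}) returns {1, 2, 0}.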
+std::vector<int64> InversePermutation(
+ tensorflow::gtl::ArraySlice<int64> input_permutation);
+
+// Composes two permutations: output[i] = p1[p2[i]].
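+// Illustrative example: ComposePermutations({1, 0, 2}, {2, 1, 0}) returns
+// {2, 0, 1}.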
+std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
+ tensorflow::gtl::ArraySlice<int64> p2);
+
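+// Returns the position of `value` within `container` (presumably the index of
+// its first occurrence; this summary is inferred from the declaration alone).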
+int64 PositionInContainer(tensorflow::gtl::ArraySlice<int64> container,
+ int64 value);
+
+// Returns a PaddingConfig object that represents no padding for the given rank.
+PaddingConfig MakeNoPaddingConfig(int64 rank);
+
+// Imports the templated FloorOfRatio math function from the TensorFlow
+// namespace, as it is very commonly used.
+template <typename T>
+T FloorOfRatio(T dividend, T divisor) {
+ return tensorflow::MathUtil::FloorOfRatio<T>(dividend, divisor);
+}
+
+// Imports the templated CeilOfRatio math function from the TensorFlow
+// namespace, as it is very commonly used.
+template <typename T>
+T CeilOfRatio(T dividend, T divisor) {
+ return tensorflow::MathUtil::CeilOfRatio<T>(dividend, divisor);
+}
+
+// Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
+// then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16
+template <typename T>
+T RoundUpToNearest(T value, T divisor) {
+ return CeilOfRatio(value, divisor) * divisor;
+}
+
+// Given a number of flops executed in an amount of time, produces a string that
+// represents the throughput;
+// e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s.
+string HumanReadableNumFlops(double flops, double nanoseconds);
+
+// Split the text into multiple lines and log each line with the given
+// severity, filename, and line number.
+void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
+ int lineno);
+
+template <typename T>
+inline bool IsPowerOfTwo(T x) {
+ static_assert(!std::numeric_limits<T>::is_signed, "unsigned types only");
+ return x != 0 && (x & (x - 1)) == 0;
+}
+
+// Returns a mask with "bits" number of least significant bits set.
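+// Illustrative note: LsbMaskU32(3) == 0x7. The shift below assumes
+// 0 <= bits <= 31, since shifting a uint32 by 32 or more is undefined
+// behavior.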
+inline uint32 LsbMaskU32(int bits) {
+ CHECK_GE(bits, 0);
+ return (1U << bits) - 1;
+}
+
+// Utility for performing a static_cast<> on a std::unique_ptr<>.
+template <typename Derived, typename Base>
+std::unique_ptr<Derived> unique_ptr_static_cast(std::unique_ptr<Base> ptr) {
+ return std::unique_ptr<Derived>(static_cast<Derived*>(ptr.release()));
+}
+
+} // namespace xla
+
+#define XLA_LOG_LINES(SEV, STRING) LogLines(SEV, STRING, __FILE__, __LINE__)
+
+#define XLA_VLOG_LINES(LEVEL, STRING) \
+ do { \
+ if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(tensorflow::INFO, STRING); \
+ } while (false)
+
+// Utility macro that performs the equivalent of what one would expect
+// LOG_LINES(FATAL, X) to do but can be used at the end of a function that
+// returns a value without getting a compiler warning that no value is returned.
+#define XLA_FATAL_LOG(X) \
+ XLA_LOG_LINES(tensorflow::ERROR, X); \
+ LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors.";
+
+#endif // TENSORFLOW_COMPILER_XLA_UTIL_H_
diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc
new file mode 100644
index 0000000000..0faf3e6ecc
--- /dev/null
+++ b/tensorflow/compiler/xla/util_test.cc
@@ -0,0 +1,89 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/util.h"
+
+#include <list>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace {
+
+// Verifies that, even with a different number of leading spaces, the
+// Reindent routine turns them into a uniform number of leading spaces.
+//
+// Also throws in some trailing whitespace on the original to show it is
+// removed.
+TEST(UtilTest, ReindentsDifferentNumberOfLeadingSpacesUniformly) {
+ string original = R"( hello there
+ world)";
+ string got = Reindent(original, " ");
+ string want = R"( hello there
+ world)";
+ EXPECT_EQ(want, got);
+}
+
+// Some smoke tests for ContainersEqual. Keeping it simple since these are just
+// basic wrappers around std::equal.
+TEST(UtilTest, ContainersEqualDefault) {
+ std::vector<int> c1 = {1, 2, 3, 4};
+ std::vector<int> c2 = {1, 2, 3};
+ std::vector<int> c3 = {};
+ std::vector<int> c4 = {1, 2, 3, 4};
+ std::vector<int> c5 = {1, 2, 3, 4, 5};
+ std::vector<int> c6 = {1, 3, 4, 5};
+
+ EXPECT_TRUE(ContainersEqual(c1, c4));
+ EXPECT_TRUE(ContainersEqual(c4, c1));
+ EXPECT_FALSE(ContainersEqual(c1, c2));
+ EXPECT_FALSE(ContainersEqual(c2, c1));
+ EXPECT_FALSE(ContainersEqual(c1, c3));
+ EXPECT_FALSE(ContainersEqual(c3, c1));
+ EXPECT_FALSE(ContainersEqual(c1, c5));
+ EXPECT_FALSE(ContainersEqual(c5, c1));
+ EXPECT_FALSE(ContainersEqual(c1, c6));
+ EXPECT_FALSE(ContainersEqual(c6, c1));
+}
+
+TEST(UtilTest, ContainersEqualPredicate) {
+ std::vector<int> c1 = {1, 2, 3, 4};
+ std::vector<int> c2 = {10, 20, 30, 40};
+
+ EXPECT_TRUE(ContainersEqual(
+ c1, c2, [](const int& i1, const int& i2) { return i1 < i2; }));
+ EXPECT_FALSE(ContainersEqual(
+ c1, c2, [](const int& i1, const int& i2) { return i1 > i2; }));
+}
+
+TEST(UtilTest, ContainersEqualDifferentContainerTypes) {
+ std::vector<int> c1 = {1, 2, 3, 4};
+ std::list<int> c2 = {1, 2, 3, 4};
+
+ EXPECT_TRUE(ContainersEqual(c1, c2));
+}
+
+TEST(UtilTest, HumanReadableNumFlopsExample) {
+ ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
+}
+
+TEST(UtilTest, LogLines) {
+ // Just make sure this code runs (not verifying the output).
+ LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
+}
+
+} // namespace
+} // namespace xla
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
new file mode 100644
index 0000000000..98c6b08e44
--- /dev/null
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/window_util.h"
+
+#include <algorithm>
+#include <vector>
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+namespace window_util {
+
+/* static */ string ToString(const WindowDimension& dim) {
+ using tensorflow::strings::StrCat;
+ using tensorflow::strings::StrAppend;
+ string str = StrCat("(size=", dim.size());
+ if (dim.stride() != 1) {
+ StrAppend(&str, ",stride=", dim.stride());
+ }
+ if (dim.padding_low() != 0) {
+ StrAppend(&str, ",padding_low=", dim.padding_low());
+ }
+ if (dim.padding_high() != 0) {
+ StrAppend(&str, ",padding_high=", dim.padding_high());
+ }
+ if (dim.base_dilation() != 1) {
+ StrAppend(&str, ",base_dilation=", dim.base_dilation());
+ }
+ if (dim.window_dilation() != 1) {
+ StrAppend(&str, ",window_dilation=", dim.window_dilation());
+ }
+ StrAppend(&str, ")");
+ return str;
+}
+
+string ToString(const Window& window) {
+ std::vector<string> window_dimension_strings;
+ for (const auto& window_dimension : window.dimensions()) {
+ window_dimension_strings.push_back(ToString(window_dimension));
+ }
+ return "{" + tensorflow::str_util::Join(window_dimension_strings, ", ") + "}";
+}
+
+bool HasStride(const Window& window) {
+ for (const auto& dim : window.dimensions()) {
+ if (dim.stride() != 1) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasPadding(const Window& window) {
+ for (const auto& dim : window.dimensions()) {
+ if (dim.padding_low() != 0 || dim.padding_high() != 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasEvenPadding(const Window& window) {
+ return std::all_of(window.dimensions().begin(), window.dimensions().end(),
+ [](const WindowDimension& dim) {
+ return dim.padding_low() == dim.padding_high();
+ });
+}
+
+bool HasNegativePadding(const Window& window) {
+ return std::any_of(window.dimensions().begin(), window.dimensions().end(),
+ [](const WindowDimension& dim) {
+ return dim.padding_low() < 0 || dim.padding_high() < 0;
+ });
+}
+
+bool HasBaseDilation(const Window& window) {
+ for (const auto& dim : window.dimensions()) {
+ if (dim.base_dilation() != 1) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasWindowDilation(const Window& window) {
+ for (const auto& dim : window.dimensions()) {
+ if (dim.window_dilation() != 1) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasDilation(const Window& window) {
+ return HasBaseDilation(window) || HasWindowDilation(window);
+}
+
+int64 DilatedBound(int64 bound, int64 dilation) {
+ CHECK_GE(bound, 0);
+ CHECK_GE(dilation, 1);
+
+ // Suppose the array has three entries 123 and the dilation factor is 4. Then
+ // the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
+ // the last expands into 4 entries, so that is (bound - 1) * dilation. Then we
+ // add 1 to account for the final input element.
+ return (bound - 1) * dilation + 1;
+}
+
+int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
+ CHECK_GE(window_size, 0);
+ CHECK_GE(bound, 0);
+ CHECK_GE(stride, 1);
+
+ if (window_size > bound) {
+ return 0;
+ }
+
+ // Without considering stride, the maximum valid offset is bound -
+ // window_size. Taking stride into account, the valid offsets then have the
+ // form q * stride for q = 0, ..., Q such that q * stride <= bound -
+ // window_size. This implies that Q equals floor((bound - window_size) /
+ // stride). There are Q + 1 valid values of q, yielding the formula below.
+ return (bound - window_size) / stride + 1;
+}
+
+} // namespace window_util
+} // namespace xla
diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h
new file mode 100644
index 0000000000..235cb2d59d
--- /dev/null
+++ b/tensorflow/compiler/xla/window_util.h
@@ -0,0 +1,66 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
+#define TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace window_util {
+
+string ToString(const WindowDimension& dim);
+string ToString(const Window& window);
+
+// The below functions return true if the given field is set to have a
+// non-trivial effect, e.g. having a stride means that the stride of some
+// dimension is not one. Whether the proto field is populated is not a
+// consideration.
+
+bool HasStride(const Window& window);
+bool HasPadding(const Window& window);
+bool HasEvenPadding(const Window& window);
+bool HasNegativePadding(const Window& window);
+
+bool HasBaseDilation(const Window& window);
+bool HasWindowDilation(const Window& window);
+bool HasDilation(const Window& window);
+
+// Returns the new bound after dilation.
+//
+// If a window with the given bound in some dimension is dilated with the given
+// dilation factor in that dimension, then the value returned is the bound for
+// the array in that dimension after dilation.
+//
+// For a 1D array with 3 entries 1, 2, 3, a dilation factor of 2 yields a new
+// window with values 1, x, 2, x, 3, where x indicates holes left by the
+// dilation. So DilatedBound(3, 2) == 5.
+int64 DilatedBound(int64 bound, int64 dilation);
+
+// Returns the number of valid positions of a window with the given size and
+// stride within an array with the given bound. This is the bound of an output
+// array with one element per valid position of the window.
+//
+// For example, for arguments of (bound=5, window_size=2, stride=2), the
+// returned value is 2. There are valid positions at offset 0 and offset 2,
+// while offset 4 is not valid since the window's last entry would be at index
+// 5, which is out of range for a bound of 5.
+int64 StridedBound(int64 bound, int64 window_size, int64 stride);
+
+} // namespace window_util
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_WINDOW_UTIL_H_
diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl
new file mode 100644
index 0000000000..bdd3dfe82d
--- /dev/null
+++ b/tensorflow/compiler/xla/xla.bzl
@@ -0,0 +1,22 @@
+"""Wrapper around cc_proto_library used inside the XLA codebase."""
+
+load("@protobuf//:protobuf.bzl", "cc_proto_library")
+
+# xla_proto_library() is a convenience wrapper around cc_proto_library.
+def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0):
+ cc_proto_library(name=name,
+ srcs=srcs,
+ deps=deps,
+ cc_libs = ["@protobuf//:protobuf"],
+ protoc="@protobuf//:protoc",
+ default_runtime="@protobuf//:protobuf",
+ testonly=testonly,
+ visibility=visibility,)
+
+# Flags required for modules that export symbols that are to be called by the
+# XLA CustomCall operator. CustomCall must be able to find symbols with dlsym(),
+# which on Linux requires we link with --export-dynamic.
+export_dynamic_linkopts = select({
+ "//tensorflow:darwin": [],
+ "//conditions:default": ["-Wl,--export-dynamic"],
+})
diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto
new file mode 100644
index 0000000000..5082a1c9a7
--- /dev/null
+++ b/tensorflow/compiler/xla/xla.proto
@@ -0,0 +1,291 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+import "tensorflow/compiler/xla/xla_data.proto";
+import "tensorflow/compiler/xla/service/session.proto";
+
+package xla;
+
+message SnapshotComputationRequest {
+ ComputationHandle computation = 1;
+}
+
+message SnapshotComputationResponse {
+ SessionModule module = 1;
+}
+
+message LoadComputationSnapshotRequest {
+ SessionModule module = 1;
+}
+
+message LoadComputationSnapshotResponse {
+ ComputationHandle computation = 1;
+}
+
+message GetDeviceHandlesRequest {
+ int64 device_count = 1;
+}
+
+message GetDeviceHandlesResponse {
+ repeated DeviceHandle device_handles = 1;
+}
+
+message TransferToClientRequest {
+ GlobalDataHandle data = 1;
+
+ // This optional field directs the service to return the literal in this
+ // layout. A shape is used to hold the layout to accommodate tuples.
+ Shape shape_with_layout = 2;
+}
+
+message TransferToClientResponse {
+ Literal literal = 1;
+}
+
+message TransferToServerRequest {
+ Literal literal = 1;
+ DeviceHandle device_handle = 2;
+}
+
+message TransferToServerResponse {
+ GlobalDataHandle data = 1;
+}
+
+message TransferToServerInProcessRequest {
+ uint64 buffer = 1;
+ Shape shape = 2;
+}
+
+message TransferToServerInProcessResponse {
+ GlobalDataHandle data = 1;
+}
+
+message TransferToClientInProcessRequest {
+ GlobalDataHandle data = 1;
+ uint64 buffer = 2;
+}
+
+message TransferToClientInProcessResponse {
+}
+
+message TransferToInfeedRequest {
+ Literal literal = 1;
+ int64 replica_id = 2;
+ DeviceHandle device_handle = 3;
+}
+
+message TransferToInfeedResponse {
+}
+
+message ResetDeviceRequest {
+ DeviceHandle device_handle = 1;
+}
+
+message ResetDeviceResponse {
+}
+
+message ComputationStatsRequest {
+ ComputationHandle computation = 1;
+}
+
+message ComputationStatsResponse {
+ ComputationStats stats = 1;
+}
+
+message ComputationRequest {
+ string name = 1;
+}
+
+message ComputationResponse {
+ ComputationHandle computation = 1;
+}
+
+message CreateChannelHandleRequest {
+}
+
+message CreateChannelHandleResponse {
+ ChannelHandle channel = 1;
+}
+
+message UnregisterRequest {
+ GlobalDataHandle data = 1;
+}
+
+message UnregisterResponse {
+}
+
+message SetReturnValueRequest {
+ ComputationHandle computation = 1;
+ ComputationDataHandle operand = 2;
+}
+
+message SetReturnValueResponse {
+}
+
+message ExecuteRequest {
+ ComputationHandle computation = 1;
+ repeated GlobalDataHandle arguments = 2;
+
+ // This optional field is a hint to the service to store the result of the
+ // computation with a particular layout. Subsequent transfers of the array to
+ // the client may be faster using this layout. A shape is used to hold the
+ // layout to accommodate computations which have tuple output.
+ Shape shape_with_output_layout = 3;
+
+ // This optional field specifies the graph-level seed. This value could be
+ // populated with a host-side generated random number to add some additional
+ // entropy to the device. A seed of 0 represents no seed set.
+ // TODO(b/32083678): This forces a recompilation.
+ uint64 seed = 4;
+
+ // This optional field specifies a particular device to run the computation.
+ // If not provided, the default device will be chosen.
+ DeviceHandle device_handle = 5;
+}
+
+message ExecuteParallelRequest {
+ repeated ExecuteRequest requests = 1;
+}
+
+message ExecuteResponse {
+ GlobalDataHandle output = 1;
+ ExecutionProfile profile = 2;
+}
+
+message ExecuteParallelResponse {
+ repeated ExecuteResponse responses = 1;
+}
+
+message ExecuteAsyncRequest {
+ ComputationHandle computation = 1;
+ repeated GlobalDataHandle arguments = 2;
+
+ // This optional field is a hint to the service to store the result of the
+ // computation with a particular layout. Subsequent transfers of the array to
+ // the client may be faster using this layout. A shape is used to hold the
+ // layout to accommodate computations which have tuple output.
+ Shape shape_with_output_layout = 3;
+
+ // This optional field specifies the graph-level seed. This value
+ // could be populated with a host-side generated random number to add some
+ // additional entropy to the device.
+ // TODO(b/32083678): This forces a recompilation.
+ uint64 seed = 4;
+}
+
+message ExecuteAsyncResponse {
+ // A handle to the execution launched asynchronously.
+ ExecutionHandle execution = 1;
+}
+
+message WaitForExecutionRequest {
+ ExecutionHandle execution = 1;
+}
+
+message WaitForExecutionResponse {
+ GlobalDataHandle output = 1;
+ ExecutionProfile profile = 2;
+}
+
+message IsConstantRequest {
+ ComputationHandle computation = 1;
+ ComputationDataHandle operand = 2;
+}
+
+message IsConstantResponse {
+ bool is_constant = 1;
+}
+
+message ComputeConstantRequest {
+ ComputationHandle computation = 1;
+ ComputationDataHandle operand = 2;
+ Layout output_layout = 3;
+}
+
+message ComputeConstantResponse {
+ GlobalDataHandle output = 1;
+}
+
+message DeconstructTupleRequest {
+ GlobalDataHandle tuple_handle = 2;
+}
+
+message DeconstructTupleResponse {
+ repeated GlobalDataHandle element_handles = 1;
+}
+
+message LoadDataRequest {
+ // Describes the path of the ColumnIO tablet to load.
+ string columnio_tablet_path = 1;
+
+ // Describes the field to load within the ColumnIO tablet.
+ string columnio_field = 2;
+
+ // Individual element shape, excluding rows.
+ Shape element_shape = 3;
+
+ // Warning: ColumnIO does not support random-access, so use offset with
+ // caution in performance-critical scenarios.
+ int64 offset = 4;
+
+ // Maximum number of elements (with shape element_shape) to load.
+ int64 limit = 5;
+
+ // If more than one item is requested (via limit > 1), then this request
+ // attribute zips together the produced vectors.
+ bool zip = 6;
+}
+
+message LoadDataResponse {
+ GlobalDataHandle data = 1;
+ Shape data_shape = 2;
+ int64 available_rows = 3;
+ int64 rows_loaded = 4;
+ int64 nanoseconds = 5;
+}
+
+message SpecializeRequest {
+ ComputationHandle computation = 1;
+ repeated GlobalDataHandle arguments = 2;
+}
+
+message SpecializeResponse {
+}
+
+message GetShapeRequest {
+ GlobalDataHandle data = 1;
+}
+
+message GetShapeResponse {
+ Shape shape = 1;
+}
+
+message GetComputationShapeRequest {
+ ComputationHandle computation = 1;
+}
+
+message GetComputationShapeResponse {
+ ProgramShape program_shape = 1;
+}
+
+message UnpackRequest {
+ GlobalDataHandle data = 1;
+}
+
+message UnpackResponse {
+ repeated GlobalDataHandle tied_data = 1;
+}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
new file mode 100644
index 0000000000..4a19d86e77
--- /dev/null
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -0,0 +1,714 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+// Primitive types are the individual values that can be held in rectangular
+// multidimensional arrays. A description of the rectangular multidimensional
+// array dimensions / primitive type is given by Shape, below.
+enum PrimitiveType {
+ // Invalid primitive type to serve as default.
+ PRIMITIVE_TYPE_INVALID = 0;
+
+ // Predicates are two-state booleans.
+ PRED = 1;
+
+ // Signed integral values of fixed width.
+ S8 = 2;
+ S16 = 3;
+ S32 = 4;
+ S64 = 5;
+
+ // Unsigned integral values of fixed width.
+ U8 = 6;
+ U16 = 7;
+ U32 = 8;
+ U64 = 9;
+
+ // Floating-point values of fixed width.
+ //
+ // Note: if f16s are not natively supported on the device, they will be
+ // converted to f16 from f32 at arbitrary points in the computation.
+ F16 = 10;
+ F32 = 11;
+ F64 = 12;
+
+ // A tuple is a polymorphic sequence; e.g. a shape that holds different
+ // sub-shapes. They are used for things like returning multiple values from a
+ // computation; e.g. a computation that returns weights and biases may have a
+ // signature that results in a tuple like (f32[784x2000], f32[2000])
+ //
+ // Tuples are currently special in that they may only be rank 0.
+ TUPLE = 13;
+
+ // An opaque type used for passing context specific data to a custom
+ // operation.
+ OPAQUE = 14;
+}
+
+// Describes the value held inside padding elements.
+enum PaddingValue {
+ INVALID_PAD = 0;
+
+ // Zero padding must be 0-values that correspond to the shape's element type.
+ ZERO_PAD = 1;
+
+ // One padding must be 1-values that correspond to the shape's element type.
+ ONE_PAD = 2;
+
+ // "Lowest" padding must be the lowest values in the shape's element type,
+ // used as padding for operations like max-accumulation.
+ LOWEST_PAD = 3;
+
+ // "Highest" padding must be the largest values in the shape's element type,
+ // used as padding for operations like min-accumulation.
+ HIGHEST_PAD = 4;
+
+ // Unknown padding could be anything; e.g. floating-point NaNs!
+ UNKNOWN_PAD = 5;
+}
+
+// Describes the padding configuration for Pad operation. The padding amount on
+// both edges as well as between the elements are specified for each dimension.
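+//
+// Illustrative example: for a 3-element dimension [a, b, c], a configuration of
+// edge_padding_low=1, interior_padding=1, edge_padding_high=2 produces
+// [p, a, p, b, p, c, p, p] (where p is the padding value), i.e. 8 elements.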
+message PaddingConfig {
+ // Describes the padding configuration for a dimension.
+ message PaddingConfigDimension {
+ // Padding amount on the low-end (next to the index 0).
+ int64 edge_padding_low = 1;
+
+ // Padding amount on the high-end (next to the highest index).
+ int64 edge_padding_high = 2;
+
+ // Padding amount between the elements.
+ int64 interior_padding = 3;
+ }
+
+ // The padding configuration for all dimensions.
+ repeated PaddingConfigDimension dimensions = 1;
+}
+
+// A layout describes how the array is placed in (1D) memory space. This
+// includes the minor-to-major ordering of dimensions within a shape, as well as
+// any padding present in those dimensions.
+//
+// Clients must specify the layouts of input Literals to the
+// computation. Layouts specified in interior operations which take Shapes (for
+// example, Convert) are ignored.
+//
+// See the XLA documentation for more information on shapes and layouts.
+message Layout {
+ // Sequence of dimension numbers, from minor (fastest varying index) to major
+ // (slowest varying index). This field is required.
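+ //
+ // Illustrative example: for a 2D array with dimensions [rows, cols], a
+ // row-major placement (the cols index varies fastest) corresponds to
+ // minor_to_major = {1, 0}.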
+ repeated int64 minor_to_major = 1;
+
+ // The width to which the layout of each dimension is padded. If present,
+ // the size of the padded_dimensions must equal the
+ // rank of the shape. The padding appears at the end of a dimension,
+ // not at the beginning. This kind of padding, unlike padding in
+ // e.g. convolution, is not part of the shape.
+ repeated int64 padded_dimensions = 2;
+
+ // Describes the values in the padding specified by
+ // padded_dimensions.
+ PaddingValue padding_value = 3;
+
+ // Important: if any field is added, be sure to modify ShapeUtil::Equal()
+ // appropriately to account for the new field.
+}
+
+// A shape describes the number of dimensions in the array, the size of each
+// dimension, and the primitive component type.
+//
+// Tuples are a special case in that they have rank zero and have tuple_shapes
+// defined.
+//
+// See the XLA documentation for more information on shapes and layouts.
+message Shape {
+ reserved 1;
+ reserved "rank";
+
+ // The element type for this shape.
+ PrimitiveType element_type = 2;
+
+ // The size (number of elements) for each dimension.
+ // In XLA, dimensions are numbered from 0 to N-1 for an
+ // N-dimensional array. The first element of 'dimensions' is the size of
+ // dimension 0, the second element is the size of dimension 1, and so forth.
+ // Empty list indicates a scalar.
+ repeated int64 dimensions = 3;
+
+ // For tuples only, the shapes of the constituent elements in the tuple.
+ repeated Shape tuple_shapes = 4;
+
+ // The layout used to back this shape.
+ Layout layout = 5;
+
+ // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
+ // ShapeUtil::Compatible() appropriately to account for the new field.
+}
+
+// Shape of the parameters and output of a computation (like a traditional
+// function signature).
+message ProgramShape {
+ repeated Shape parameters = 1;
+ Shape result = 2;
+ repeated string parameter_names = 3;
+}
+
+// Statistics of a computation.
+message ComputationStats {
+ // The number of floating point operations in the computation.
+ double flop_count = 1;
+
+ // The number of transcendental operations (e.g., exp) in the computation.
+ double transcendental_count = 2;
+}
+
+// Profile data from the execution of a computation.
+message ExecutionProfile {
+ // Whether the executable was read from the compilation cache.
+ bool compilation_cache_hit = 1;
+
+ // The time in milliseconds spent to compile the computation. This is only set
+ // if the executable was not read from the compilation cache
+ // (compilation_cache_hit == false).
+ int64 compile_time_ms = 2;
+
+ // The number of cycles spent for the computation. This does not include the
+ // time taken for the data transfers between the host and the device. This is
+ // a target-dependent field and only used for debugging purposes.
+ int64 compute_cycle_count = 3;
+
+ // The time in nanoseconds spent for the computation, without data transfer.
+ int64 compute_time_ns = 4;
+
+ // The time in nanoseconds spent for the entire computation, including the
+ // result data transfer time. Current implementation does not spend any cycles
+ // for the input data transfer since the memory is initialized with the proper
+ // values before the execution.
+ int64 compute_and_transfer_time_ns = 5;
+}
+
+// Handle given to a user that represents a computation that the user builds up
+// before execution.
+message ComputationHandle {
+ int64 handle = 1;
+}
+
+// Handle given to a user that represents an execution that the user launched
+// asynchronously on the device.
+message ExecutionHandle {
+ int64 handle = 1;
+}
+
+// Handle given to a user that represents a globally accessible allocation.
+// Contrast this against a ComputationDataHandle, which is not globally
+// accessible, since it only exists within a specific computation.
+message GlobalDataHandle {
+ int64 handle = 1;
+}
+
+// Handle given to a user that represents a data result in a computation.
+// This is passed to subsequent computations that depend upon the data as
+// an operand.
+message ComputationDataHandle {
+ int64 handle = 1;
+}
+
+// Handle given to a user that represents a device to execute a computation.
+// When replication is enabled, the device handle represents the device for the
+// replica id 0.
+message DeviceHandle {
+ int64 handle = 1;
+}
+
+// Handle given to a user to represent a channel between two computations
+// via a Send and Recv instruction pair. Channels are unbuffered, so Send
+// instructions will be blocked until the data is transferred.
+message ChannelHandle {
+ int64 handle = 1;
+}
+
+// Literals are used when the server and client need to exchange materialized
+// data / results. Literals are also used to describe constants used in
+// computations.
+//
+// Transfers to/from the client are encoded in literal form, and the structure
+// of the repeated fields is implied by the shape.
+message Literal {
+ Shape shape = 1;
+ repeated bool preds = 2;
+ bytes u8s = 3;
+ repeated int32 s32s = 4;
+ repeated int64 s64s = 5;
+ repeated uint32 u32s = 6;
+ repeated uint64 u64s = 7;
+ repeated float f32s = 8;
+ repeated double f64s = 9;
+ repeated Literal tuple_literals = 10;
+}
+
+message WindowDimension {
+ // The size of the window in this dimension. For a rectangle, this would be
+ // the width or height.
+ int64 size = 1;
+
+ // The stride at which the window moves across the base area in this
+ // dimension. In other words, this is the spacing between different
+ // positions of the window in this dimension.
+ int64 stride = 2;
+
+ // If positive, this is the amount of padding with zeroes to add to the base
+ // area at the low end of this dimension; if negative, its absolute value is
+ // the number of elements removed from the low end of this dimension. For
+ // example, in the horizontal dimension of a rectangle, this would be the
+ // number of zeroes to pad on the left, given that indices increase when
+ // going right.
+ int64 padding_low = 3;
+
+ // As padding_low, but on the high end of this dimension. For
+ // example, in the horizontal dimension of a rectangle, this would
+ // be the number of zeroes to pad on the right, given that indices
+ // increase when going right.
+ int64 padding_high = 4;
+
+ // Dilation factor of the sliding window in this dimension. A dilation factor
+ // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
+ // implicitly placed between each kernel element. See documentation for
+ // convolution.
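+ //
+ // Illustrative example: a 3-element kernel [k0, k1, k2] with
+ // window_dilation = 2 behaves like the 5-element kernel [k0, h, k1, h, k2],
+ // where each h is an implicit hole.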
+ int64 window_dilation = 5;
+
+ // Dilation factor of the base area in this dimension. A dilation factor of 1
+ // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
+ // placed between each base area element. See documentation for convolution.
+ int64 base_dilation = 6;
+}
+
+// Describes the windowing in an operation such as convolution.
+//
+// The window is moved across a base area and for each position of the
+// window a computation is performed. The field below describes the
+// window and the movement of the window across a base area.
+message Window {
+ repeated WindowDimension dimensions = 1;
+}
+
+// Operation requests that are all collected as a tagged union with a oneof
+// field in OpRequest.
+
+message ConstantRequest {
+ Literal literal = 2;
+}
+
+message GetTupleElementRequest {
+ ComputationDataHandle operand = 2;
+ int64 index = 3;
+}
+
+message SliceRequest {
+ ComputationDataHandle operand = 2;
+ repeated int64 start_indices = 3;
+ repeated int64 limit_indices = 4;
+}
+
+message DynamicSliceRequest {
+ // Operand from which to slice at dynamic 'start_indices'.
+ ComputationDataHandle operand = 2;
+ // Dynamically computed 'start_indices' for slice operation.
+ ComputationDataHandle start_indices = 3;
+ // Slice sizes for each dimension (note that index calculations are performed
+ // modulo the dimension sizes to avoid out-of-bounds array accesses).
+ repeated int64 slice_sizes = 4;
+}
+
+message DynamicUpdateSliceRequest {
+ // Operand on which slice 'update' is to be applied.
+ ComputationDataHandle operand = 2;
+ // The slice update to apply to 'operand'.
+ ComputationDataHandle update = 3;
+ // Dynamically computed start indices for the update slice operation.
+ ComputationDataHandle start_indices = 4;
+}
+
+message ConvolutionDimensionNumbers {
+ // The number of the dimension that represents batch in the input
+ // (lhs) and output.
+ int64 batch_dimension = 1;
+
+ // The number of the dimension that represents features in the input
+ // (lhs) and output.
+ int64 feature_dimension = 2;
+
+ // The dimension numbers for the spatial dimensions that the window
+ // moves through in the input (lhs) and output.
+ repeated int64 spatial_dimensions = 5;
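+
+ // Illustrative example: for an input laid out as NHWC (batch, height, width,
+ // features), batch_dimension = 0, feature_dimension = 3 and
+ // spatial_dimensions = {1, 2}.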
+
+ // The number of the dimension that represents input features in the
+ // convolutional kernel (rhs).
+ int64 kernel_input_feature_dimension = 3;
+
+ // The number of the dimension that represents output features in
+ // the convolutional kernel (rhs).
+ int64 kernel_output_feature_dimension = 4;
+
+ // The dimension numbers for the spatial dimensions that the window
+ // moves through in the kernel (rhs). window.strides(0) is the
+ // stride in the kernel_spatial_dimensions(0) dimension.
+ repeated int64 kernel_spatial_dimensions = 6;
+};
+
+message ConvolveRequest {
+ ComputationDataHandle lhs = 2;
+ ComputationDataHandle rhs = 3; // This is the filter/kernel.
+ Window window = 4; // Describes the filter/kernel.
+ ConvolutionDimensionNumbers dimension_numbers = 5;
+}
+
+message InfeedRequest {
+ // The shape of the data returned by reading the device's infeed buffer.
+ Shape shape = 2;
+}
+
+message CallRequest {
+ ComputationHandle to_apply = 2;
+ repeated ComputationDataHandle operands = 3;
+}
+
+message CustomCallRequest {
+ string call_target_name = 2;
+ repeated ComputationDataHandle operands = 3;
+ Shape shape = 4;
+}
+
+message MapRequest {
+ repeated ComputationDataHandle operands = 2;
+ ComputationHandle to_apply = 3;
+ repeated ComputationDataHandle static_operands = 4;
+}
+
+message ReduceRequest {
+ // Operand to the reduction.
+ ComputationDataHandle operand = 2;
+
+ // Initial value for the reduction. This must be consistent with the result
+ // shape of to_apply.
+ ComputationDataHandle init_value = 3;
+
+ // The dimensions to reduce over.
+ repeated int64 dimensions = 4;
+
+ // The computation to apply in the reduction.
+ ComputationHandle to_apply = 5;
+}
+
+message ReduceWindowRequest {
+ ComputationDataHandle operand = 2;
+ ComputationDataHandle init_value = 3;
+ Window window = 4;
+ ComputationHandle to_apply = 5;
+}
+
+message CrossReplicaSumRequest {
+ ComputationDataHandle operand = 2;
+}
+
+message SelectAndScatterRequest {
+ // Operand array on which the windows slide.
+ ComputationDataHandle operand = 2;
+
+ // Source array for the data to scatter.
+ ComputationDataHandle source = 3;
+
+ // Initial scalar value for each element in the output.
+ ComputationDataHandle init_value = 4;
+
+ // Window configuration.
+ Window window = 5;
+
+ // Binary function used to select an element from each window.
+ ComputationHandle select = 6;
+
+ // Binary function used to combine each scattered value from source with the
+ // current output value at the selected location.
+ ComputationHandle scatter = 7;
+}
+
+message ReverseRequest {
+ ComputationDataHandle operand = 2;
+ repeated int64 dimensions = 3;
+}
+
+message BroadcastRequest {
+ ComputationDataHandle operand = 2;
+ repeated int64 broadcast_sizes = 3;
+}
+
+message PadRequest {
+ ComputationDataHandle operand = 2;
+ ComputationDataHandle padding_value = 3;
+ PaddingConfig padding_config = 4;
+}
+
+message ReshapeRequest {
+ ComputationDataHandle operand = 2;
+
+ // The dimension order for collapse (from fastest-changing to slowest).
+ repeated int64 dimensions = 3;
+
+ // The new dimension sizes (from dimension 0 to n-1).
+ repeated int64 new_sizes = 4;
+}
+
+message ParameterRequest {
+ Shape shape = 2;
+ int64 parameter = 3;
+ string name = 4;
+}
+
+message GetLocalShapeRequest {
+ ComputationHandle computation = 1;
+ ComputationDataHandle operand = 2;
+}
+
+message GetLocalShapeResponse {
+ Shape shape = 1;
+}
+
+message TraceRequest {
+ string tag = 2;
+ ComputationDataHandle operand = 3;
+}
+
+message ConvertRequest {
+ ComputationDataHandle operand = 2;
+ PrimitiveType new_element_type = 3;
+}
+
+message ConcatenateRequest {
+ repeated ComputationDataHandle operands = 2;
+ // The dimension in which we concatenate; e.g. if you had arrays with shapes
+ // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
+ // Attempting to concatenate those in dimension 1 would produce an error, as
+ // 4 != 5 (and there is no ragged array support).
+ int64 dimension = 3;
+}
+
+message WhileRequest {
+ ComputationHandle condition = 2;
+ ComputationHandle body = 3;
+ ComputationDataHandle init = 4;
+}
+
+enum UnaryOperation {
+ UNOP_INVALID = 0;
+
+ // Elementwise, logical negation
+ UNOP_LOGICAL_NOT = 1;
+
+ // Elementwise, computes e^x.
+ UNOP_EXP = 2;
+
+ // Elementwise, computes -x.
+ UNOP_NEGATE = 3;
+
+ // Puts the elements in the operand into sorted order.
+ UNOP_SORT = 4;
+
+ // Elementwise, computes tanh(x).
+ UNOP_TANH = 5;
+
+ // Elementwise, computes the natural logarithm of x.
+ UNOP_LOG = 6;
+
+ // Elementwise, computes the floor of x.
+ UNOP_FLOOR = 7;
+
+ // Elementwise, computes the ceil of x.
+ UNOP_CEIL = 8;
+
+ // Elementwise, computes the abs of x.
+ UNOP_ABS = 9;
+
+ // Elementwise, computes the sign of x.
+ UNOP_SIGN = 10;
+}
+
+message UnaryOpRequest {
+ UnaryOperation unop = 2;
+ ComputationDataHandle operand = 3;
+}
+
+enum BinaryOperation {
+ BINOP_INVALID = 0;
+
+ // Arithmetic operations.
+ BINOP_ADD = 1;
+ BINOP_DIV = 2;
+ BINOP_MUL = 3;
+ BINOP_SUB = 4;
+
+ // Comparison operators.
+ BINOP_EQ = 5;
+ BINOP_GE = 6;
+ BINOP_GT = 7;
+ BINOP_LE = 8;
+ BINOP_LT = 9;
+ BINOP_NE = 10;
+
+ // Dot product, matrix multiply.
+ BINOP_DOT = 12;
+
+ // Indexes into the LHS with the RHS.
+ //
+ // If the RHS is higher-rank, this is a gather operation.
+ //
+ // Note: currently out of bounds indices may crash the underlying XLA
+ // machine.
+ BINOP_INDEX = 13;
+
+ // Element-wise maximum.
+ BINOP_MAX = 14;
+
+ // Element-wise minimum.
+ BINOP_MIN = 15;
+
+ // Raises the left-hand-side to the right-hand-side power.
+ BINOP_POW = 16;
+
+ // Remainder operation.
+ BINOP_REM = 17;
+
+ // Logical operators
+ BINOP_LOGICAL_AND = 18;
+ BINOP_LOGICAL_OR = 19;
+}
+
+message BinaryOpRequest {
+ BinaryOperation binop = 2;
+ ComputationDataHandle lhs = 3;
+ ComputationDataHandle rhs = 4;
+ repeated int64 broadcast_dimensions = 5;
+}
+
+enum RandomDistribution {
+ RNG_INVALID = 0;
+
+ // Creates a uniform-distribution-generated random number on the interval
+ // [parameter[0], parameter[1]].
+ RNG_UNIFORM = 1;
+
+ // Creates a normal-distribution-generated random number with mean
+ // parameter[0] and standard deviation parameter[1].
+ RNG_NORMAL = 2;
+
+ // Creates a Bernoulli-distribution-generated random number with mean
+ // parameter[0].
+ RNG_BERNOULLI = 3;
+}
+
+message RngRequest {
+ RandomDistribution distribution = 2;
+ repeated ComputationDataHandle parameter = 3;
+ Shape shape = 4;
+}
+
+enum TernaryOperation {
+ TRIOP_INVALID = 0;
+
+ // Given a predicate and two operands, selects operand0 if the predicate is
+ // true and operand1 if the predicate is false.
+ TRIOP_SELECT = 1;
+
+ // Updates operand0 at index operand1 with value operand2 and outputs the
+ // updated value.
+ TRIOP_UPDATE = 2;
+
+ // Given a min, a max, and an operand, returns the operand if it is between
+ // min and max; otherwise returns min if the operand is less than min, or max
+ // if the operand is greater than max.
+ TRIOP_CLAMP = 3;
+}
+
+message TernaryOpRequest {
+ TernaryOperation triop = 2;
+ ComputationDataHandle lhs = 3;
+ ComputationDataHandle rhs = 4;
+ ComputationDataHandle ehs = 5;
+}
+
+enum VariadicOperation {
+ VAROP_INVALID = 0;
+
+ // Creates a tuple from its operands.
+ VAROP_TUPLE = 1;
+}
+
+message VariadicOpRequest {
+ VariadicOperation varop = 2;
+ repeated ComputationDataHandle operands = 3;
+}
+
+message SendRequest {
+ ComputationDataHandle operand = 1;
+ ChannelHandle channel_handle = 2;
+}
+
+message RecvRequest {
+ Shape shape = 1;
+ ChannelHandle channel_handle = 2;
+}
+
+message OpRequest {
+ ComputationHandle computation = 1;
+
+ oneof op {
+ BinaryOpRequest binary_op_request = 2;
+ BroadcastRequest broadcast_request = 3;
+ CallRequest call_request = 4;
+ ConcatenateRequest concatenate_request = 5;
+ ConstantRequest constant_request = 6;
+ ConvertRequest convert_request = 7;
+ ConvolveRequest convolve_request = 8;
+ CrossReplicaSumRequest cross_replica_sum_request = 9;
+ CustomCallRequest custom_call_request = 10;
+ DynamicSliceRequest dynamic_slice_request = 11;
+ DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
+ GetTupleElementRequest get_tuple_element_request = 13;
+ InfeedRequest infeed_request = 14;
+ MapRequest map_request = 15;
+ PadRequest pad_request = 16;
+ ParameterRequest parameter_request = 17;
+ ReduceRequest reduce_request = 18;
+ ReduceWindowRequest reduce_window_request = 19;
+ ReshapeRequest reshape_request = 20;
+ ReverseRequest reverse_request = 21;
+ RngRequest rng_request = 22;
+ SelectAndScatterRequest select_and_scatter_request = 23;
+ SliceRequest slice_request = 24;
+ TernaryOpRequest ternary_op_request = 25;
+ TraceRequest trace_request = 26;
+ UnaryOpRequest unary_op_request = 27;
+ VariadicOpRequest variadic_op_request = 28;
+ WhileRequest while_request = 29;
+ SendRequest send_request = 30;
+ RecvRequest recv_request = 31;
+ // Next: 32
+ }
+}
+
+message OpResponse {
+ ComputationDataHandle output = 1;
+}
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index b63ec64970..bc3b60ef28 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -4,6 +4,7 @@ package(default_visibility = [":friends"])
package_group(
name = "friends",
+ includes = ["//tensorflow/compiler/jit:friends"],
packages = ["//tensorflow/..."],
)
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index abad729d8b..80c23b1df1 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -6,6 +6,7 @@ load("@protobuf//:protobuf.bzl", "py_proto_library")
# configure may change the following lines to True
WITH_GCP_SUPPORT = False
WITH_HDFS_SUPPORT = False
+WITH_XLA_SUPPORT = False
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
@@ -184,4 +185,13 @@ def tf_additional_lib_deps():
return deps
def tf_additional_plugin_deps():
- return []
+ deps = []
+ if WITH_XLA_SUPPORT:
+ deps.append("//tensorflow/compiler/jit")
+ return deps
+
+def tf_additional_license_deps():
+ licenses = []
+ if WITH_XLA_SUPPORT:
+ licenses.append("@llvm//:LICENSE.TXT")
+ return licenses
diff --git a/tensorflow/tools/ci_build/builds/configured b/tensorflow/tools/ci_build/builds/configured
index 2776942593..f4c3ae1077 100755
--- a/tensorflow/tools/ci_build/builds/configured
+++ b/tensorflow/tools/ci_build/builds/configured
@@ -33,6 +33,10 @@ export TF_NEED_GCP=1
# Enable support for HDFS
export TF_NEED_HDFS=1
+# Enable XLA support
+# export TF_ENABLE_XLA=${TF_ENABLE_XLA:-0}
+export TF_ENABLE_XLA=1
+
if [[ "$1" == "--disable-gcp" ]]; then
export TF_NEED_GCP=0
shift 1
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index b99dbe954e..5570cea32f 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -4,6 +4,7 @@
package(default_visibility = ["//visibility:private"])
load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps")
# This returns a list of headers of all public header libraries (e.g.,
# framework, lib), and all of the transitive dependencies of those
@@ -94,7 +95,7 @@ filegroup(
"@protobuf//:LICENSE",
"@six_archive//:LICENSE",
"@zlib_archive//:zlib.h",
- ],
+ ] + tf_additional_license_deps(),
)
sh_binary(
diff --git a/third_party/llvm/llvm.BUILD b/third_party/llvm/llvm.BUILD
index 03ca850f96..0f7ef74545 100644
--- a/third_party/llvm/llvm.BUILD
+++ b/third_party/llvm/llvm.BUILD
@@ -1584,6 +1584,7 @@ cc_library(
"include/llvm/ExecutionEngine/RTDyldMemoryManager.h",
"lib/ExecutionEngine/RuntimeDyld/*.h",
"lib/ExecutionEngine/RuntimeDyld/Targets/*.h",
+ "lib/ExecutionEngine/RuntimeDyld/Targets/*.cpp",
"lib/ExecutionEngine/RuntimeDyld/*.h",
]),
hdrs = glob([